Yihua7 committed on
Commit
6b92ff7
·
0 Parent(s):

Initial commit: AniGen - Animatable 3D Generation

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +58 -0
  2. Dockerfile +63 -0
  3. README.md +231 -0
  4. THIRD_PARTY_LICENSES.md +30 -0
  5. anigen/__init__.py +6 -0
  6. anigen/datasets/__init__.py +32 -0
  7. anigen/datasets/anigen_sparse_feat2skeleton.py +290 -0
  8. anigen/datasets/anigen_sparse_structure.py +124 -0
  9. anigen/datasets/anigen_sparse_structure_latent.py +238 -0
  10. anigen/datasets/anigen_structured_latent.py +327 -0
  11. anigen/datasets/components.py +143 -0
  12. anigen/models/__init__.py +67 -0
  13. anigen/models/anigen_sparse_structure_flow.py +487 -0
  14. anigen/models/anigen_sparse_structure_vae.py +729 -0
  15. anigen/models/anigen_structured_latent_flow.py +553 -0
  16. anigen/models/sparse_elastic_mixin.py +24 -0
  17. anigen/models/structured_latent_vae/__init__.py +3 -0
  18. anigen/models/structured_latent_vae/anigen_base.py +256 -0
  19. anigen/models/structured_latent_vae/anigen_decoder.py +834 -0
  20. anigen/models/structured_latent_vae/anigen_encoder.py +318 -0
  21. anigen/models/structured_latent_vae/base.py +117 -0
  22. anigen/models/structured_latent_vae/skin_models.py +252 -0
  23. anigen/modules/attention/__init__.py +36 -0
  24. anigen/modules/attention/full_attn.py +140 -0
  25. anigen/modules/attention/modules.py +161 -0
  26. anigen/modules/norm.py +25 -0
  27. anigen/modules/sparse/__init__.py +102 -0
  28. anigen/modules/sparse/attention/__init__.py +5 -0
  29. anigen/modules/sparse/attention/full_attn.py +215 -0
  30. anigen/modules/sparse/attention/modules.py +151 -0
  31. anigen/modules/sparse/attention/serialized_attn.py +193 -0
  32. anigen/modules/sparse/attention/windowed_attn.py +135 -0
  33. anigen/modules/sparse/attention/windowed_attn_cross.py +131 -0
  34. anigen/modules/sparse/basic.py +465 -0
  35. anigen/modules/sparse/conv/__init__.py +21 -0
  36. anigen/modules/sparse/conv/conv_spconv.py +80 -0
  37. anigen/modules/sparse/conv/conv_torchsparse.py +38 -0
  38. anigen/modules/sparse/linear.py +15 -0
  39. anigen/modules/sparse/nonlinearity.py +35 -0
  40. anigen/modules/sparse/norm.py +58 -0
  41. anigen/modules/sparse/spatial.py +110 -0
  42. anigen/modules/sparse/transformer/__init__.py +3 -0
  43. anigen/modules/sparse/transformer/anigen_modulated.py +155 -0
  44. anigen/modules/sparse/transformer/blocks.py +259 -0
  45. anigen/modules/sparse/transformer/modulated.py +174 -0
  46. anigen/modules/spatial.py +48 -0
  47. anigen/modules/transformer/__init__.py +2 -0
  48. anigen/modules/transformer/blocks.py +285 -0
  49. anigen/modules/transformer/modulated.py +175 -0
  50. anigen/modules/utils.py +54 -0
.gitattributes ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ anigen/representations/mesh/flexicubes/images/block_init.png filter=lfs diff=lfs merge=lfs -text
37
+ anigen/representations/mesh/flexicubes/images/teaser_top.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/cond_images/dog.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/cond_images/lamp.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/cond_images/machine_arm.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/cond_images/owl.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/cond_images/trex.png filter=lfs diff=lfs merge=lfs -text
43
+ assets/cond_images/whale.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/images/teaser.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/cond_images/machine_dog.png filter=lfs diff=lfs merge=lfs -text
46
+ assets/cond_images/spongebob.png filter=lfs diff=lfs merge=lfs -text
47
+ assets/gifs/eagle.gif filter=lfs diff=lfs merge=lfs -text
48
+ assets/gifs/evo.gif filter=lfs diff=lfs merge=lfs -text
49
+ assets/gifs/horse.gif filter=lfs diff=lfs merge=lfs -text
50
+ assets/gifs/iron_boy.gif filter=lfs diff=lfs merge=lfs -text
51
+ assets/gifs/machine_arm.gif filter=lfs diff=lfs merge=lfs -text
52
+ assets/gifs/machine_dog.gif filter=lfs diff=lfs merge=lfs -text
53
+ assets/gifs/mairo.gif filter=lfs diff=lfs merge=lfs -text
54
+ assets/gifs/money_tree.gif filter=lfs diff=lfs merge=lfs -text
55
+ assets/cond_images/brickbob.png filter=lfs diff=lfs merge=lfs -text
56
+ assets/cond_images/bruno_star.png filter=lfs diff=lfs merge=lfs -text
57
+ assets/cond_images/evo.png filter=lfs diff=lfs merge=lfs -text
58
+ assets/cond_images/iron_boy.png filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ python3.10 python3-pip python3.10-dev \
7
+ git git-lfs ffmpeg libsm6 libxext6 libgl1 libegl1 \
8
+ build-essential ninja-build cmake rsync \
9
+ && rm -rf /var/lib/apt/lists/* \
10
+ && git lfs install
11
+
12
+ RUN useradd -m -u 1000 user
13
+ USER user
14
+
15
+ ENV HOME=/home/user \
16
+ PATH=/home/user/.local/bin:$PATH \
17
+ PYTHONUNBUFFERED=1 \
18
+ PIP_NO_CACHE_DIR=1 \
19
+ HF_HOME=/home/user/.cache/huggingface \
20
+ TORCH_HOME=/home/user/.cache/torch \
21
+ ATTN_BACKEND=xformers \
22
+ SPARSE_ATTN_BACKEND=xformers \
23
+ TORCH_CUDA_ARCH_LIST="7.5;8.6;8.9"
24
+
25
+ WORKDIR $HOME/app
26
+
27
+ COPY --chown=user:user . $HOME/app
28
+
29
+ RUN python3.10 -m pip install --upgrade pip setuptools wheel
30
+
31
+ RUN python3.10 -m pip install \
32
+ torch==2.4.0 torchvision==0.19.0 \
33
+ --index-url https://download.pytorch.org/whl/cu121
34
+
35
+ RUN python3.10 -m pip install \
36
+ pillow imageio imageio-ffmpeg tqdm easydict scipy ninja psutil safetensors \
37
+ scikit-learn opencv-python-headless rembg onnxruntime \
38
+ trimesh xatlas pyvista pymeshfix igraph pygltflib geffnet \
39
+ transformers \
40
+ gradio==4.44.1 gradio_litmodel3d==0.0.1 "huggingface_hub<0.25"
41
+
42
+ RUN python3.10 -m pip install \
43
+ git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
44
+
45
+ RUN python3.10 -m pip install \
46
+ "pytorch3d==0.7.8" \
47
+ --find-links https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html
48
+
49
+ RUN python3.10 -m pip install \
50
+ xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121
51
+
52
+ RUN python3.10 -m pip install spconv-cu121
53
+
54
+ RUN python3.10 -m pip install \
55
+ kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.4.0_cu121.html
56
+
57
+ RUN python3.10 -m pip install \
58
+ "git+https://github.com/NVlabs/nvdiffrast.git" --no-build-isolation
59
+
60
+ EXPOSE 7860
61
+
62
+ CMD ["python3.10", "app.py"]
63
+
README.md ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: AniGen
3
+ sdk: gradio
4
+ sdk_version: 4.44.1
5
+ python_version: 3.10.13
6
+ startup_duration_timeout: 2h
7
+ ---
8
+
9
+ <h1 align="center">AniGen: Unified S<sup>3</sup> Fields for Animatable 3D Asset Generation</h1>
10
+ <p align="center"><a href="https://arxiv.org/pdf/2604.08746"><img src='https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv&logoColor=white' alt='arXiv'></a>
11
+ <a href='https://yihua7.github.io/AniGen_web/'><img src='https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white' alt='Project Page'></a>
12
+ <a href='https://huggingface.co/spaces/VAST-AI/AniGen'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Live_Demo-blue'></a>
13
+ <a href="https://www.tripo3d.ai"><img src="https://img.shields.io/badge/Tripo-AI_3D_Workspace-orange" alt="Tripo"></a>
14
+ </p>
15
+ <p align="center"><img src="assets/images/teaser.png" width="100%"></p>
16
+
17
+ <span style="font-size: 16px; font-weight: 600;">A</span><span style="font-size: 12px; font-weight: 700;">niGen</span> is a unified framework that directly generates animate-ready 3D assets conditioned on a single image. Our key insight is to represent shape, skeleton, and skinning as mutually consistent *$S^3$ Fields* (Shape, Skeleton, Skin) defined over a shared spatial domain. To enable the robust learning of these fields, we introduce two technical innovations: (i) a *confidence-decaying skeleton field* that explicitly handles the geometric ambiguity of bone prediction at Voronoi boundaries, and (ii) a *dual skin feature field* that decouples skinning weights from specific joint counts, allowing a fixed-architecture network to predict rigs of arbitrary complexity. Built upon a two-stage flow-matching pipeline, <span style="font-size: 16px; font-weight: 600;">A</span><span style="font-size: 12px; font-weight: 700;">niGen</span> first synthesizes a sparse structural scaffold and then generates dense geometry and articulation in a structured latent space. Extensive experiments demonstrate that <span style="font-size: 16px; font-weight: 600;">A</span><span style="font-size: 12px; font-weight: 700;">niGen</span> substantially outperforms state-of-the-art sequential baselines in rig validity and animation quality, generalizing effectively to in-the-wild images across diverse categories including animals, humanoids, and machinery.
18
+
19
+ <!-- Overview -->
20
+ ## 🔮 Overview
21
+
22
+ AniGen takes a **single image** as input and automatically produces a fully rigged, animate-ready 3D asset, complete with a coherent mesh, an articulated skeleton, and smooth skinning weights. The generated assets can be directly imported into standard 3D pipelines and driven by off-the-shelf motion data, enabling immediate deployment across a wide spectrum of downstream applications, including **embodied AI** agent construction, **physics-based simulation**, **character animation**, **dynamic scene creation**, and **articulated object manipulation**.
23
+
24
+ <table width="100%">
25
+ <tr>
26
+ <td width="25%" align="center"><img src="assets/gifs/machine_arm.gif" width="100%"><br><b>Machine Arm</b></td>
27
+ <td width="25%" align="center"><img src="assets/gifs/machine_dog.gif" width="100%"><br><b>Machine Dog</b></td>
28
+ <td width="25%" align="center"><img src="assets/gifs/money_tree.gif" width="100%"><br><b>Money Tree</b></td>
29
+ <td width="25%" align="center"><img src="assets/gifs/iron_boy.gif" width="100%"><br><b>Iron Boy</b></td>
30
+ </tr>
31
+ <tr>
32
+ <td width="25%" align="center"><img src="assets/gifs/mairo.gif" width="100%"><br><b>Mario</b></td>
33
+ <td width="25%" align="center"><img src="assets/gifs/evo.gif" width="100%"><br><b>Evo</b></td>
34
+ <td width="25%" align="center"><img src="assets/gifs/horse.gif" width="100%"><br><b>Horse</b></td>
35
+ <td width="25%" align="center"><img src="assets/gifs/eagle.gif" width="100%"><br><b>Eagle</b></td>
36
+ </tr>
37
+ </table>
38
+
39
+ <!-- Installation -->
40
+ ## 📦 Installation
41
+
42
+ ### Prerequisites
43
+ - **System**: The code is currently tested only on **Linux**.
44
+ - **Hardware**: An NVIDIA GPU with at least 18GB of memory is necessary. The code has been verified on NVIDIA A800 and RTX3090 GPUs.
45
+ - **Software**:
46
+ - The [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) is needed to compile certain submodules. The code has been tested with CUDA versions 11.8 and 12.2.
47
+ - [Conda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) is recommended for managing dependencies.
48
+ - Python version 3.8 or higher is required.
49
+
50
+ ### Installation Steps
51
+ 1. Clone the repo:
52
+ ```sh
53
+ git clone --recurse-submodules https://github.com/VAST-AI-Research/AniGen.git
54
+ cd AniGen
55
+ ```
56
+
57
+ 2. Install the dependencies:
58
+
59
+ We recommend using [uv](https://docs.astral.sh/uv/) for fast, reliable installs. The setup script will also work with plain `pip` if `uv` is not available.
60
+
61
+ Create a new virtual environment and install everything:
62
+ ```sh
63
+ source ./setup.sh --new-env --all
64
+ ```
65
+
66
+ If your network connection to PyPI is unstable or slow, you can use the Tsinghua mirror:
67
+ ```sh
68
+ source ./setup.sh --new-env --all --tsinghua
69
+ ```
70
+
71
+ If you already have an environment with PyTorch installed, install into it directly:
72
+ ```sh
73
+ source ./setup.sh --basic
74
+ ```
75
+
76
+ > [!NOTE]
77
+ > The setup script auto-detects your CUDA version and installs matching wheels for PyTorch, spconv, pytorch3d, and nvdiffrast. [DSINE](https://github.com/baegwangbin/DSINE) (used for surface normal estimation) is loaded at runtime via `torch.hub` and does not require separate installation.
78
+
79
+
80
+ <!-- Pretrained Models -->
81
+ ## 🤖 Pretrained Models
82
+
83
+ We provide the following pretrained models on [Hugging Face](https://huggingface.co/VAST-AI/AniGen/tree/main). Please make sure to download all necessary weights from this page, including the required dinov2, dsine, and vgg checkpoints.
84
+
85
+ > [!TIP]
86
+ > **Recommended:** Use **SS-Flow-Duet** + **SLAT-Flow-Auto** if you do not have specific requirements.
87
+ > - For more detailed skeleton (including character fingers) → **SS-Flow-Duet**
88
+ > - For better geometry generalization → **SS-Flow-Solo**
89
+ > - **SLAT-Flow-Control** supports density levels 0–4, but if the density condition significantly deviates from the proper value for the object, skinning weights may be damaged.
90
+
91
+ | DAE Model | Description | Download |
92
+ | --- | --- | --- |
93
+ | SS-DAE | Encoder&Decoder of SS | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/ss_dae) |
94
+ | SLAT-DAE | Encoder&Decoder of SLAT | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/slat_dae) |
95
+
96
+ | SS Model | Description | Download |
97
+ | --- | --- | --- |
98
+ | SS-Flow-Duet | Detailed Skeleton (Full-FT Geo) | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/ss_flow_duet) |
99
+ | SS-Flow-Epic | Geometry&Skeleton Balanced (LoRA-FT Geo) | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/ss_flow_epic) |
100
+ | SS-Flow-Solo | Accurate Geometry (Freeze Geo) | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/ss_flow_solo) |
101
+
102
+ | SLAT Model | Description | Download |
103
+ | --- | --- | --- |
104
+ | SLAT-Flow-Auto | Automatically Determine Joint Number | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/slat_flow_auto) |
105
+ | SLAT-Flow-Control | Controllable Joint Density | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/slat_flow_control) |
106
+
107
+ <!-- Usage -->
108
+ ## 💡 Usage
109
+
110
+ ### Minimal Example
111
+
112
+ Here is an [example](example.py) of how to use the pretrained models for 3D asset generation.
113
+
114
+
115
+ After running the code, you will get the following files:
116
+ - `mesh.glb`: a rigged mesh file
117
+ - `skeleton.glb`: a skeleton visualization file
118
+ - `processed_image.png`: the masked image as the condition
119
+
120
+ ### AniGen Pipeline (Rigged Mesh + Skeleton)
121
+
122
+ For AniGen checkpoints in this repo (e.g. `ckpts/anigen/ss_flow_solo` + `ckpts/anigen/slat_flow_control`), you can run:
123
+ ```sh
124
+ python example.py --image_path assets/cond_images/trex.png
125
+ ```
126
+
127
+ ### Web Demo
128
+
129
+ [app.py](app.py) provides a simple web demo for 3D asset generation. Since this demo is based on [Gradio](https://gradio.app/), additional dependencies are required:
130
+ ```sh
131
+ source ./setup.sh --demo
132
+ ```
133
+
134
+ If needed, you can also install the demo dependencies via the Tsinghua mirror:
135
+ ```sh
136
+ source ./setup.sh --demo --tsinghua
137
+ ```
138
+
139
+ After installing the dependencies, you can run the demo with the following command:
140
+ ```sh
141
+ python app.py
142
+ ```
143
+
144
+ Then, you can access the demo at the address shown in the terminal.
145
+
146
+ ***The web demo is also available on [Hugging Face Spaces](https://huggingface.co/spaces/VAST-AI/AniGen)!***
147
+
148
+
149
+ <!-- Training -->
150
+ ## 🏋️ Training
151
+
152
+ ### Training Data
153
+
154
+ Sample training data is available at [AniGen_sample_data](https://huggingface.co/datasets/VAST-AI/AniGen_sample_data). To prepare your own data, refer to [TRELLIS](https://github.com/microsoft/TRELLIS) and the sample data format.
155
+
156
+ ### Prerequisites
157
+
158
+ > [!NOTE]
159
+ > Training requires the **CUBVH** extension (`extensions/CUBVH/`), which is automatically built by `setup.sh`. It is **not** needed for inference (`app.py`, `example.py`).
160
+
161
+ ### Training Commands
162
+
163
+ The pipeline has five stages. Later stages depend on earlier ones, so please train in order:
164
+
165
+ ```sh
166
+ # Stage 1: Skin AutoEncoder
167
+ python train.py --config configs/anigen_skin_ae.json --output_dir outputs/anigen_skin_ae
168
+
169
+ # Stage 2: Sparse Structure DAE
170
+ python train.py --config configs/ss_dae.json --output_dir outputs/ss_dae
171
+
172
+ # Stage 3: Structured Latent DAE
173
+ python train.py --config configs/slat_dae.json --output_dir outputs/slat_dae
174
+
175
+ # Stage 4: SS Flow Matching (image-conditioned generation)
176
+ python train.py --config configs/ss_flow_duet.json --output_dir outputs/ss_flow_duet
177
+
178
+ # Stage 5: SLAT Flow Matching (image-conditioned generation)
179
+ python train.py --config configs/slat_flow_auto.json --output_dir outputs/slat_flow_auto
180
+ ```
181
+
182
+ ### Multi-Node / Multi-GPU
183
+
184
+ Append the following flags for distributed training across multiple machines and GPUs:
185
+
186
+ ```sh
187
+ python train.py --config configs/<config>.json --output_dir outputs/<output> \
188
+ --num_nodes XX --node_rank XX --master_addr XX --master_port XX
189
+ ```
190
+
191
+ ### Model Variants
192
+
193
+ Other SS Flow variants (`ss_flow_epic`, `ss_flow_solo`) and SLAT Flow variants (`slat_flow_control`, `slat_flow_gsn_auto`) are available under `ckpts/anigen/`. Their config files can be found at `ckpts/anigen/<variant>/config.json`.
194
+
195
+ ### Resume / Restart
196
+
197
+ Training automatically resumes from the latest checkpoint in `--output_dir`. To start fresh, pass `--ckpt none`.
198
+
199
+
200
+ ## License
201
+
202
+ This project's source code is released under the [MIT License](LICENSE).
203
+
204
+ > [!IMPORTANT]
205
+ > This repository includes third-party components with additional license restrictions. In particular, `extensions/CUBVH/` contains BVH code derived from NVIDIA's [instant-ngp](https://github.com/NVlabs/instant-ngp), which is licensed for **non-commercial / research use only**. See [THIRD_PARTY_LICENSES.md](THIRD_PARTY_LICENSES.md) for details.
206
+
207
+ ## Acknowledgements
208
+
209
+ - [TRELLIS](https://github.com/microsoft/TRELLIS) by Microsoft
210
+ - [cuBVH](https://github.com/ashawkey/cubvh) by Jiaxiang Tang
211
+ - [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) and [instant-ngp](https://github.com/NVlabs/instant-ngp) by Thomas Müller / NVIDIA
212
+ - [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes) by NVIDIA
213
+
214
+ We sincerely appreciate the contributions of these excellent projects and their authors. We believe open source helps accelerate research, lower barriers to innovation, and make progress more accessible to the broader community.
215
+
216
+
217
+ <!-- Citation -->
218
+ ## 📜 Citation
219
+
220
+ If you find this work helpful, please consider citing our paper:
221
+
222
+ ```bibtex
223
+ @article{huang2026anigen,
224
+ title = {AniGen: Unified $S^3$ Fields for Animatable 3D Asset Generation},
225
+ author = {Huang, Yi-Hua and Zhou, Zi-Xin and He, Yuting and Chang, Chirui
226
+ and Pu, Cheng-Feng and Yang, Ziyi and Guo, Yuan-Chen
227
+ and Cao, Yan-Pei and Qi, Xiaojuan},
228
+ journal = {ACM SIGGRAPH},
229
+ year = {2026}
230
+ }
231
+ ```
THIRD_PARTY_LICENSES.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Third-Party Licenses
2
+
3
+ ## extensions/CUBVH/ — cuBVH (CUDA Mesh BVH Acceleration)
4
+
5
+ Originally created by [Jiaxiang Tang (ashawkey)](https://github.com/ashawkey/cubvh),
6
+ modified by Yi-Hua Huang (yihua7).
7
+
8
+ ### MIT License (cubvh overall)
9
+ - File: `extensions/CUBVH/LICENSE`
10
+ - Copyright (c) 2022 Jiaxiang Tang (ashawkey)
11
+ - Copyright (c) 2025 Yi-Hua Huang (yihua7)
12
+
13
+ ### NVIDIA Source Code License — Non-Commercial (BVH from instant-ngp)
14
+ - File: `extensions/CUBVH/LICENSE_NVIDIA`
15
+ - Copyright (c) 2022, NVIDIA Corporation & affiliates
16
+ - **USE RESTRICTED TO NON-COMMERCIAL / RESEARCH PURPOSES ONLY**
17
+
18
+ ### BSD 3-Clause License (gpu_memory.h from tiny-cuda-nn)
19
+ - File header: `extensions/CUBVH/include/gpu/gpu_memory.h`
20
+ - Copyright (c) 2020-2022, NVIDIA CORPORATION
21
+
22
+ ### Apache License 2.0 (pcg32.h)
23
+ - File header: `extensions/CUBVH/include/gpu/pcg32.h`
24
+ - Author: Wenzel Jakob, modified by tiny-cuda-nn
25
+
26
+ ## anigen/representations/mesh/flexicubes/ — FlexiCubes
27
+
28
+ - File: `anigen/representations/mesh/flexicubes/LICENSE.txt`
29
+ - Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES
30
+ - Apache License 2.0
anigen/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from . import models
2
+ from . import modules
3
+ from . import pipelines
4
+ from . import renderers
5
+ from . import representations
6
+ from . import utils
anigen/datasets/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import importlib

# Lazy-loading registry (PEP 562): maps each exported class name to the
# submodule of this package that actually defines it. Nothing is imported
# until the attribute is first accessed.
__attributes = {
    'AniGenSparseStructure': 'anigen_sparse_structure',
    'AniGenSparseFeat2Skeleton': 'anigen_sparse_feat2skeleton',
    'AniGenSparseFeat2Render': 'anigen_sparse_feat2render',

    'AniGenSparseStructureLatent': 'anigen_sparse_structure_latent',
    'TextConditionedAniGenSparseStructureLatent': 'anigen_sparse_structure_latent',
    'ImageConditionedAniGenSparseStructureLatent': 'anigen_sparse_structure_latent',

    'AniGenSLat': 'anigen_structured_latent',
    'AniGenTextConditionedSLat': 'anigen_structured_latent',
    'AniGenImageConditionedSLat': 'anigen_structured_latent',
}

# Submodules re-exported as-is (none at the moment).
__submodules = []

__all__ = list(__attributes.keys()) + __submodules


def __getattr__(name):
    """Resolve *name* lazily on first access and cache it in module globals."""
    if name in globals():
        return globals()[name]
    if name in __attributes:
        # Import the defining submodule and pull the requested class out of it.
        submodule = importlib.import_module(f".{__attributes[name]}", __name__)
        globals()[name] = getattr(submodule, name)
    elif name in __submodules:
        globals()[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
anigen/datasets/anigen_sparse_feat2skeleton.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import json
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import utils3d.torch
8
+ from ..modules.sparse.basic import SparseTensor
9
+ from .components import StandardDatasetBase
10
+
11
+
12
+ class AniGenSparseFeat2Skeleton(StandardDatasetBase):
13
+ """
14
+ AniGenSparseFeat2Skeleton dataset.
15
+
16
+ Args:
17
+ roots (str): paths to the dataset
18
+ image_size (int): size of the image
19
+ model (str): model name
20
+ resolution (int): resolution of the data
21
+ min_aesthetic_score (float): minimum aesthetic score
22
+ max_num_voxels (int): maximum number of voxels
23
+ """
24
+ def __init__(
25
+ self,
26
+ roots: str,
27
+ image_size: int,
28
+ model: str = 'dinov2_vitl14_reg',
29
+ resolution: int = 64,
30
+ min_aesthetic_score: float = 5.0,
31
+ max_num_voxels: int = 32768,
32
+ load_cubvh: bool = False,
33
+ skl_dilation_iter: int = 0,
34
+ skl_dilation_random_aug: bool = False,
35
+ skl_dilation_random_aug_prob: float = 0.5,
36
+ filter_bad_skin: bool = False,
37
+
38
+ test_mode: bool = True, # Test the model performance
39
+ is_test: bool = False, # Train or validation
40
+ skin_accum_as_flow: bool = False, # Accumulate skin weights from bottom to top as flow-by probability
41
+ local_rank: int = 0,
42
+ joint_merge_res: int = 64,
43
+ **kwargs,
44
+ ):
45
+ self.image_size = image_size
46
+ self.model = model
47
+ self.resolution = resolution
48
+ self.min_aesthetic_score = min_aesthetic_score
49
+ self.max_num_voxels = max_num_voxels
50
+ self.value_range = (0, 1)
51
+ self.load_cubvh = load_cubvh
52
+ self.skl_dilation_iter = skl_dilation_iter
53
+ self.skl_dilation_random_aug = skl_dilation_random_aug
54
+ self.skl_dilation_random_aug_prob = skl_dilation_random_aug_prob
55
+ self.filter_bad_skin = filter_bad_skin
56
+
57
+ self.test_mode = test_mode
58
+ self.is_test = is_test
59
+ self.skin_accum_as_flow = skin_accum_as_flow
60
+ self.local_rank = local_rank
61
+ self.joint_merge_res = joint_merge_res
62
+
63
+ super().__init__(roots, **kwargs)
64
+ self.is_bad_skin_list = self.metadata['is_bad_skin'].values
65
+
66
+ def filter_metadata(self, metadata):
67
+ stats = {}
68
+ metadata = metadata[metadata[f'feature_{self.model}']]
69
+ stats['With features'] = len(metadata)
70
+ metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
71
+ stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
72
+ metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
73
+ stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
74
+
75
+ if 'is_bad_skeleton' in metadata.columns:
76
+ metadata = metadata[~metadata['is_bad_skeleton']]
77
+ if self.filter_bad_skin and 'is_bad_skin' in metadata.columns:
78
+ metadata = metadata[~metadata['is_bad_skin']]
79
+
80
+ if self.test_mode:
81
+ metadata = metadata[metadata['is_test']] if self.is_test else metadata[~metadata['is_test']]
82
+
83
+ return metadata, stats
84
+
85
+ def _get_image(self, root, instance):
86
+ with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
87
+ metadata = json.load(f)
88
+ n_views = len(metadata['frames'])
89
+ view = np.random.randint(n_views)
90
+ metadata = metadata['frames'][view]
91
+ fov = metadata['camera_angle_x']
92
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
93
+ c2w = torch.tensor(metadata['transform_matrix'])
94
+ c2w[:3, 1:3] *= -1
95
+ extrinsics = torch.inverse(c2w)
96
+
97
+ image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
98
+ image = Image.open(image_path)
99
+ alpha = image.getchannel(3)
100
+ image = image.convert('RGB')
101
+ image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
102
+ alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
103
+ image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
104
+ alpha = torch.tensor(np.array(alpha)).float() / 255.0
105
+
106
+ return {
107
+ 'image': image,
108
+ 'alpha': alpha,
109
+ 'extrinsics': extrinsics,
110
+ 'intrinsics': intrinsics,
111
+ }
112
+
113
+ def _get_feat(self, root, instance):
114
+ DATA_RESOLUTION = 64
115
+ feats_path = os.path.join(root, 'features', self.model, f'{instance}.npz')
116
+ feats_data = np.load(feats_path, allow_pickle=True)
117
+ coords = torch.tensor(feats_data['indices']).int()
118
+ feats = torch.tensor(feats_data['patchtokens']).float()
119
+
120
+ position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}_skeleton.ply'))[0]
121
+ coords_skl = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
122
+ ss_skl = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
123
+ ss_skl[0, coords_skl[:,0], coords_skl[:,1], coords_skl[:,2]] = 1
124
+ ss_skl_ori = ss_skl.clone()
125
+ if self.skl_dilation_random_aug or self.skl_dilation_iter > 0:
126
+ size = max(0, self.skl_dilation_iter) * 2 + 1
127
+ if self.skl_dilation_iter > 0:
128
+ kernel = torch.ones(1, 1, size, size, size, dtype=torch.float32, device=ss_skl.device)
129
+ ss_skl = torch.nn.functional.conv3d(ss_skl.float()[None], kernel, padding=self.skl_dilation_iter)
130
+ ss_skl = (ss_skl > 0).long().squeeze(0)
131
+ coords_skl = torch.nonzero(ss_skl[0], as_tuple=False).int()
132
+ if self.skl_dilation_random_aug and np.random.rand() < self.skl_dilation_random_aug_prob:
133
+ size_small, size_large = size - 2, size + 2
134
+ kernel_large = torch.ones(1, 1, size_large, size_large, size_large, dtype=torch.float32, device=ss_skl_ori.device)
135
+ ss_skl_large = torch.nn.functional.conv3d(ss_skl_ori.float()[None], kernel_large, padding=size_large//2)
136
+ ss_skl_large = (ss_skl_large > 0).long().squeeze(0)
137
+ if size_small > 1:
138
+ kernel_small = torch.ones(1, 1, size_small, size_small, size_small, dtype=torch.float32, device=ss_skl_ori.device)
139
+ ss_skl_small = torch.nn.functional.conv3d(ss_skl_ori.float()[None], kernel_small, padding=size_small//2)
140
+ ss_skl_small = (ss_skl_small > 0).long().squeeze(0)
141
+ else:
142
+ ss_skl_small = torch.zeros_like(ss_skl)
143
+
144
+ ss_skl_random_mask = torch.rand_like(ss_skl.float()) < 0.5
145
+ ss_skl = ss_skl_small * ss_skl_random_mask.long() + ss_skl_large * (1 - ss_skl_random_mask.long())
146
+ coords_skl = torch.nonzero(ss_skl[0], as_tuple=False).int()
147
+ feats_skl = torch.zeros((coords_skl.shape[0], 0), dtype=torch.float32)
148
+
149
+ if self.resolution != DATA_RESOLUTION:
150
+ factor = DATA_RESOLUTION // self.resolution
151
+ coords = coords // factor
152
+ coords, idx = coords.unique(return_inverse=True, dim=0)
153
+ feats = torch.scatter_reduce(
154
+ torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
155
+ dim=0,
156
+ index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
157
+ src=feats,
158
+ reduce='mean'
159
+ )
160
+ coords_skl = coords_skl // factor
161
+ coords_skl, idx = coords_skl.unique(return_inverse=True, dim=0)
162
+ feats_skl = torch.scatter_reduce(
163
+ torch.zeros(coords_skl.shape[0], feats_skl.shape[1], device=feats_skl.device),
164
+ dim=0,
165
+ index=idx.unsqueeze(-1).expand(-1, feats_skl.shape[1]),
166
+ src=feats_skl,
167
+ reduce='mean'
168
+ )
169
+
170
+ return {
171
+ 'coords': coords,
172
+ 'feats': feats,
173
+ 'coords_skl': coords_skl,
174
+ 'feats_skl': feats_skl,
175
+ }
176
+
177
@torch.no_grad()
def visualize_sample(self, sample: dict):
    """Visualize a sample by returning its conditioning image unchanged."""
    image = sample['image']
    return image
180
+
181
@staticmethod
def collate_fn(batch):
    """Collate per-instance samples into one training batch.

    Sparse voxel data ('coords'/'feats' and their skeleton counterparts) are
    concatenated into SparseTensors with the batch index prepended as the
    first coordinate column; dense image/camera tensors are stacked; the
    variable-length skeleton annotations are kept as per-sample lists.
    """
    pack = {}
    coords = []
    coords_skl = []
    # Prepend the batch index so concatenated points can be attributed
    # to their originating sample.
    for i, b in enumerate(batch):
        coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
        coords_skl.append(torch.cat([torch.full((b['coords_skl'].shape[0], 1), i, dtype=torch.int32), b['coords_skl']], dim=-1))
    coords = torch.cat(coords)
    feats = torch.cat([b['feats'] for b in batch])
    pack['feats'] = SparseTensor(
        coords=coords,
        feats=feats,
    )
    coords_skl = torch.cat(coords_skl)
    feats_skl = torch.cat([b['feats_skl'] for b in batch])
    pack['feats_skl'] = SparseTensor(
        coords=coords_skl,
        feats=feats_skl,
    )

    pack['image'] = torch.stack([b['image'] for b in batch])
    pack['alpha'] = torch.stack([b['alpha'] for b in batch])
    pack['extrinsics'] = torch.stack([b['extrinsics'] for b in batch])
    pack['intrinsics'] = torch.stack([b['intrinsics'] for b in batch])

    # Variable-length skeleton data cannot be stacked; keep per-sample lists.
    pack['joints'] = [b['joints'] for b in batch]
    pack['parents'] = [b['parents'] for b in batch]
    pack['skin'] = [b['skin'] for b in batch]
    pack['is_bad_skin'] = [b['is_bad_skin'] for b in batch]

    # Collate any remaining keys generically. FIX: 'is_bad_skin' is now in
    # the exclusion list — it was already collected above, and re-collating
    # it here would overwrite that list (and mis-handle it via torch.stack
    # if it were ever stored as a tensor).
    handled = ['coords', 'feats', 'coords_skl', 'feats_skl', 'image', 'alpha',
               'extrinsics', 'intrinsics', 'joints', 'parents', 'skin', 'is_bad_skin']
    for k in (k for k in batch[0].keys() if k not in handled):
        if isinstance(batch[0][k], torch.Tensor):
            pack[k] = torch.stack([b[k] for b in batch])
        elif isinstance(batch[0][k], list):
            pack[k] = sum([b[k] for b in batch], [])
        else:
            pack[k] = [b[k] for b in batch]

    return pack
223
+
224
def _get_geo(self, root, instance):
    """Load the voxelized skeleton mesh for *instance* and, optionally, a
    cached cuBVH acceleration structure built over it.

    Returns a dict with key 'mesh' ({'vertices', 'faces'} CPU tensors) and,
    when self.load_cubvh is set, key 'cubvh' (a CPU-resident cuBVH).
    """
    skeleton_path = os.path.join(root, 'skeleton', instance, 'skeleton_voxelized.npz')
    skl_data = np.load(skeleton_path, allow_pickle=True)
    verts, face = np.array(skl_data['vertices'], dtype=np.float32), skl_data['faces']
    mesh = {
        "vertices" : torch.from_numpy(verts),
        "faces" : torch.from_numpy(face),
    }
    geo = {"mesh": mesh}
    if self.load_cubvh:
        from cubvh import cuBVH
        # Pin this worker to its local GPU before building/loading the BVH.
        torch.cuda.set_device(self.local_rank)
        cubvh_path = os.path.join(root, 'skeleton', instance, 'cubvh.pth')
        if os.path.exists(cubvh_path):
            # NOTE: weights_only=False unpickles arbitrary objects; the cache
            # file must come from a trusted source.
            bvh = torch.load(cubvh_path, weights_only=False)
            if isinstance(bvh, cuBVH):
                # Move cached BVH to CPU so dataloader workers don't hold GPU memory.
                bvh = bvh.to('cpu')
        else:
            # Build once on the local GPU, then cache a CPU copy on disk.
            device = torch.device(f"cuda:{self.local_rank}")
            bvh = cuBVH(mesh["vertices"], mesh["faces"], device=device)
            bvh = bvh.to('cpu')
            torch.save(bvh, cubvh_path)
        geo["cubvh"] = bvh
    return geo
248
+
249
def _get_skeleton(self, root, instance):
    """Load joints, parent indices and skinning weights for *instance*.

    Root joints get parent index -1. Vertices whose skin weights are all
    zero are bound entirely to joint 0, then the weights are renormalized
    to sum to 1 per vertex.
    """
    skeleton_path = os.path.join(root, 'skeleton', instance, 'skeleton_voxelized.npz')
    skl_data = np.load(skeleton_path, allow_pickle=True)
    joints, parents, skin = skl_data['joints'], skl_data['parents'], skl_data['skin']
    # 'parents' is an object array in which the root is stored as None.
    parents[parents==None] = -1
    parents = np.array(parents, dtype=np.int32)

    # Fallback for unskinned vertices: assign all weight to joint 0.
    # (skl_data['skin'] re-reads the pristine array from the npz, so the
    # zero-row test is unaffected by the in-place edit below.)
    skin[np.where(skl_data['skin'].max(axis=1)==0)[0], 0] = 1.0
    skin = skin / skin.sum(-1, keepdims=True)

    if self.skin_accum_as_flow:
        # Accumulate each joint's weight with its entire subtree so the
        # weight behaves like a "flow" down the kinematic hierarchy.
        root_idx = np.where(parents == -1)[0][0]
        def sum_children(joint_idx, skin_weights):
            # Recursive post-order accumulation; recursion depth equals
            # the skeleton's tree depth.
            children = np.where(parents == joint_idx)[0]
            for child in children:
                skin_weights[:, joint_idx] += sum_children(child, skin_weights)
            return skin_weights[:, joint_idx]
        sum_children(root_idx, skin)
        skin = np.clip(skin, 0, 1)

    is_bad_skin = self.metadata['is_bad_skin'][instance]

    return {
        'joints': torch.from_numpy(joints).float(),
        'parents': torch.from_numpy(parents).int(),
        'skin': torch.from_numpy(skin).float(),
        'is_bad_skin': is_bad_skin
    }
277
+
278
def get_instance(self, root, instance):
    """Assemble one full training sample by merging the image, feature,
    geometry and skeleton sub-dictionaries for *instance*."""
    parts = (
        self._get_image(root, instance),
        self._get_feat(root, instance),
        self._get_geo(root, instance),
        self._get_skeleton(root, instance),
    )
    sample = {}
    for part in parts:
        sample.update(part)
    sample['instance'] = instance
    return sample
anigen/datasets/anigen_sparse_structure.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Union
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ import utils3d
9
+ from .components import StandardDatasetBase
10
+ from ..representations.octree import DfsOctree as Octree
11
+ from ..renderers import OctreeRenderer
12
+
13
+
14
class AniGenSparseStructure(StandardDatasetBase):
    """
    Sparse structure dataset

    Args:
        roots (str): path to the dataset
        resolution (int): resolution of the voxel grid
        min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
        skl_dilation_iter (int): number of 3D dilation iterations applied to the skeleton voxels
    """

    def __init__(self,
        roots,
        resolution: int = 64,
        min_aesthetic_score: float = 5.0,
        skl_dilation_iter: int = 0,
        **kwargs,
    ):
        self.resolution = resolution
        self.min_aesthetic_score = min_aesthetic_score
        self.skl_dilation_iter = skl_dilation_iter
        # Occupancy grids are binary, hence a (0, 1) value range.
        self.value_range = (0, 1)

        super().__init__(roots)

    def filter_metadata(self, metadata):
        """Keep voxelized instances above the aesthetic threshold.

        Returns (filtered_metadata, stats) with per-stage row counts.
        """
        stats = {}
        metadata = metadata[metadata[f'voxelized']]
        stats['Voxelized'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        return metadata, stats

    def get_ply_instance(self, root, instance, dilation_iter=None):
        """Load a voxel PLY into a (1, R, R, R) binary occupancy grid,
        optionally dilated by *dilation_iter* voxels in 3D."""
        if dilation_iter is None:
            dilation_iter = self.skl_dilation_iter

        # PLY positions are in [-0.5, 0.5); shift/scale into voxel indices.
        position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
        coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
        ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
        ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
        if dilation_iter > 0:
            # 3D Dilation
            size = dilation_iter * 2 + 1
            kernel = torch.ones(1, 1, size, size, size, dtype=torch.float32, device=ss.device)
            ss = torch.nn.functional.conv3d(ss.float()[None], kernel, padding=dilation_iter)
            ss = (ss > 0).long().squeeze(0)
        return ss

    def get_instance(self, root, instance):
        """Return shape ('ss') and skeleton ('ss_skl') occupancy grids;
        only the skeleton is dilated."""
        ss = self.get_ply_instance(root, instance, dilation_iter=0)
        ss_skl = self.get_ply_instance(root, f'{instance}_skeleton', dilation_iter=self.skl_dilation_iter)
        return {'ss': ss, 'ss_skl': ss_skl, 'instance': instance}

    @torch.no_grad()
    def visualize_sample(self, ss: Union[torch.Tensor, dict]):
        """Render the occupancy grid from 4 fixed viewpoints into a 2x2
        tiled 1024x1024 image per batch element (requires CUDA)."""
        ss = ss if isinstance(ss, torch.Tensor) else ss['ss']

        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = 512
        renderer.rendering_options.near = 0.8
        renderer.rendering_options.far = 1.6
        renderer.rendering_options.bg_color = (0, 0, 0)
        renderer.rendering_options.ssaa = 4
        renderer.pipe.primitive = 'voxel'

        # Build camera
        yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = 0. # np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in yaws]
        pitch = np.linspace(-np.pi / 4, np.pi / 4, num=4) # [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]

        exts = []
        ints = []
        # Note: the loop variable 'pitch' shadows the array above; zip has
        # already captured the iterable, so this is safe.
        for yaw, pitch in zip(yaws, pitch):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(30)).cuda()
            extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
            exts.append(extrinsics)
            ints.append(intrinsics)

        images = []

        # Build each representation
        ss = ss.cuda()
        for i in range(ss.shape[0]):
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            coords = torch.nonzero(ss[i, 0], as_tuple=False)
            representation.position = coords.float() / self.resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')

            # 2x2 tile of the four rendered views; voxels are colored by position.
            image = torch.zeros(3, 1024, 1024).cuda()
            tile = [2, 2]
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
            images.append(image)

        return torch.stack(images)
124
+
anigen/datasets/anigen_sparse_structure_latent.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import *
4
+ import numpy as np
5
+ import torch
6
+ import utils3d
7
+ from ..representations.octree import DfsOctree as Octree
8
+ from ..renderers import OctreeRenderer
9
+ from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
10
+ from .. import models
11
+ from ..utils.dist_utils import read_file_dist
12
+ import torch.nn.functional as F
13
+
14
+
15
class AniGenSparseStructureLatentVisMixin:
    """Mixin that decodes sparse-structure latents with a pretrained VAE
    decoder and renders the resulting occupancy grids for visualization.
    The decoder is loaded lazily and freed again to limit GPU memory use.
    """
    def __init__(
        self,
        *args,
        pretrained_ss_dec: str = None,
        ss_dec_path: Optional[str] = '',
        ss_dec_ckpt: Optional[str] = 'final',
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Decoder handle; populated on demand by _loading_ss_dec().
        self.ss_dec = None
        self.pretrained_ss_dec = pretrained_ss_dec
        self.ss_dec_path = ss_dec_path
        self.ss_dec_ckpt = ss_dec_ckpt

    def _loading_ss_dec(self):
        # Lazily instantiate the sparse-structure decoder on the GPU.
        if self.ss_dec is not None:
            return
        if self.ss_dec_path is not None:
            # Local experiment directory: rebuild the model from its config
            # and load the named checkpoint.
            cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r'))
            decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
            ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
            decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
            # decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True)) # Got stuck...
        else:
            decoder = models.from_pretrained(self.pretrained_ss_dec)
        self.ss_dec = decoder.cuda().eval()

    def _delete_ss_dec(self):
        # Release the decoder so GPU memory is freed between visualizations.
        del self.ss_dec
        self.ss_dec = None

    @torch.no_grad()
    def decode_latent(self, z, z_skl, batch_size=4):
        """De-normalize (if stats are configured) and decode shape/skeleton
        latents into dense grids, in mini-batches to bound peak memory."""
        self._loading_ss_dec()
        ss = []
        ss_skl = []
        if self.normalization is not None:
            # Undo dataset normalization before decoding.
            z = z * self.std.to(z.device) + self.mean.to(z.device)
        if self.normalization_skl is not None:
            z_skl = z_skl * self.std_skl.to(z_skl.device) + self.mean_skl.to(z_skl.device)
        for i in range(0, z.shape[0], batch_size):
            z_, z_skl_ = z[i:i+batch_size], z_skl[i:i+batch_size]
            ss_, ss_skl_ = self.ss_dec(z_, z_skl_)
            ss.append(ss_)
            ss_skl.append(ss_skl_)
        ss = torch.cat(ss, dim=0)
        ss_skl = torch.cat(ss_skl, dim=0)
        self._delete_ss_dec()
        return ss, ss_skl

    @torch.no_grad()
    def visualize_sample(self, x_0: Union[torch.Tensor, dict], x_0_skl: Optional[Union[torch.Tensor, dict]]=None, **kwargs):
        """Render decoded shape and skeleton occupancies, 4 views each in a
        2x2 tile, concatenated side by side (left: shape, right: skeleton).

        NOTE(review): both unpacking guards test isinstance(x_0, Tensor) —
        when x_0 is a dict, x_0_skl is taken from it and the x_0_skl
        argument is ignored. Presumably intentional; confirm.
        """
        x_0_skl = x_0_skl if isinstance(x_0, torch.Tensor) else x_0['x_0_skl']
        x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
        x_0, x_0_skl = self.decode_latent(x_0.cuda(), x_0_skl.cuda())

        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = 512
        renderer.rendering_options.near = 0.8
        renderer.rendering_options.far = 1.6
        renderer.rendering_options.bg_color = (0, 0, 0)
        renderer.rendering_options.ssaa = 4
        renderer.pipe.primitive = 'voxel'

        # Build camera: four fixed yaw/pitch pairs orbiting the object.
        yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = 0 # np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in yaws]
        pitch = np.linspace(-np.pi / 4, np.pi / 4, 4) # [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]

        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitch):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(30)).cuda()
            extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
            exts.append(extrinsics)
            ints.append(intrinsics)

        images = []
        x_0 = x_0.cuda()
        for i in range(x_0.shape[0]):
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            # Positive decoded logits are treated as occupied voxels.
            coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
            resolution = x_0.shape[-1]
            representation.position = coords.float() / resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(resolution)), dtype=torch.uint8, device='cuda')
            image = torch.zeros(3, 1024, 1024).cuda()
            tile = [2, 2]
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
            images.append(image)

        x_0_skl = x_0_skl.cuda()
        for i in range(x_0_skl.shape[0]):
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            coords = torch.nonzero(x_0_skl[i, 0] > 0, as_tuple=False)
            resolution = x_0_skl.shape[-1]
            representation.position = coords.float() / resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(resolution)), dtype=torch.uint8, device='cuda')
            image = torch.zeros(3, 1024, 1024).cuda()
            tile = [2, 2]
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
            # Append the skeleton panel to the right of the shape panel.
            images[i] = torch.cat([images[i], image], dim=2)

        return torch.stack(images)
145
+
146
+
147
class AniGenSparseStructureLatent(AniGenSparseStructureLatentVisMixin, StandardDatasetBase):
    """
    Sparse structure latent dataset

    Args:
        roots (str): path to the dataset
        latent_model (str): name of the latent model
        min_aesthetic_score (float): minimum aesthetic score
        normalization (dict): normalization stats (path to an .npz with 'feats_mean'/'feats_std')
        normalization_skl (dict): normalization stats for the skeleton latents
        pretrained_ss_dec (str): name of the pretrained sparse structure decoder
        ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
        ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
    """
    def __init__(self,
        roots: str,
        *,
        latent_model: str,
        min_aesthetic_score: float = 5.0,
        normalization: Optional[dict] = None,
        normalization_skl: Optional[dict] = None,
        pretrained_ss_dec: str = None,
        ss_dec_path: Optional[str] = '',
        ss_dec_ckpt: Optional[str] = 'final',
        **kwargs,
    ):
        self.latent_model = latent_model
        self.min_aesthetic_score = min_aesthetic_score
        self.normalization = normalization
        self.normalization_skl = normalization_skl
        self.value_range = (0, 1)

        super().__init__(
            roots,
            pretrained_ss_dec=pretrained_ss_dec,
            ss_dec_path=ss_dec_path,
            ss_dec_ckpt=ss_dec_ckpt,
            **kwargs,
        )

        # Normalization stats are loaded from .npz files on disk.
        if self.normalization is not None:
            data = np.load(self.normalization)
            self.mean = torch.tensor(data['feats_mean'])
            self.std = torch.tensor(data['feats_std'])
        if self.normalization_skl is not None:
            data = np.load(self.normalization_skl)
            self.mean_skl = torch.tensor(data['feats_skl_mean'])
            # Clip tiny stds to avoid blowing up normalized values.
            self.std_skl = torch.tensor(data['feats_skl_std']).clip(min=1e-3)

    def filter_metadata(self, metadata):
        """Keep instances that have latents, pass the aesthetic threshold,
        and are not flagged bad-skeleton/bad-skin. Returns (metadata, stats)."""
        stats = {}
        metadata = metadata[metadata[f'ss_latent_{self.latent_model}']]
        stats['With sparse structure latents'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)

        if 'is_bad_skeleton' in metadata.columns:
            metadata = metadata[~metadata['is_bad_skeleton']]
        if 'is_bad_skin' in metadata.columns:
            metadata = metadata[~metadata['is_bad_skin']]

        return metadata, stats

    def get_instance(self, root, instance):
        """Load the (optionally normalized) shape and skeleton latent means
        for one instance."""
        latent = np.load(os.path.join(root, 'ss_latents', self.latent_model, f'{instance}.npz'))
        z = torch.tensor(latent['mean']).float()
        z_skl = torch.tensor(latent['mean_skl']).float()
        if self.normalization is not None:
            z = (z - self.mean) / self.std
        if self.normalization_skl is not None:
            z_skl = (z_skl - self.mean_skl) / self.std_skl

        pack = {
            'instance': instance,
            'x_0': z,
            'x_0_skl': z_skl,
        }
        return pack
224
+
225
+
226
class TextConditionedAniGenSparseStructureLatent(TextConditionedMixin, AniGenSparseStructureLatent):
    """Sparse-structure latent dataset with text conditioning mixed in."""
231
+
232
+
233
class ImageConditionedAniGenSparseStructureLatent(ImageConditionedMixin, AniGenSparseStructureLatent):
    """Sparse-structure latent dataset with image conditioning mixed in."""
238
+
anigen/datasets/anigen_structured_latent.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import *
4
+ import numpy as np
5
+ import torch
6
+ import utils3d.torch
7
+ from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
8
+ from ..modules.sparse.basic import SparseTensor
9
+ from .. import models
10
+ from ..utils.render_utils import get_renderer
11
+ from ..utils.dist_utils import read_file_dist
12
+ from ..utils.data_utils import load_balanced_group_indices
13
+ import copy
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class AniGenSLatVisMixin:
18
+ def __init__(
19
+ self,
20
+ *args,
21
+ pretrained_slat_dec: str = None,
22
+ slat_dec_path: Optional[str] = None,
23
+ slat_dec_ckpt: Optional[str] = None,
24
+ load_cubvh: bool = False,
25
+ **kwargs
26
+ ):
27
+ super().__init__(*args, **kwargs)
28
+ self.slat_dec = None
29
+ self.pretrained_slat_dec = pretrained_slat_dec
30
+ self.slat_dec_path = slat_dec_path
31
+ self.slat_dec_ckpt = slat_dec_ckpt
32
+ self.load_cubvh = load_cubvh
33
+
34
+ def _loading_slat_dec(self):
35
+ if self.slat_dec is not None:
36
+ return
37
+ if self.slat_dec_path is not None:
38
+ cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
39
+ decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
40
+ ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
41
+ # decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
42
+ decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
43
+ else:
44
+ decoder = models.from_pretrained(self.pretrained_slat_dec)
45
+ self.slat_dec = decoder.cuda().eval()
46
+
47
+ def _delete_slat_dec(self):
48
+ del self.slat_dec
49
+ self.slat_dec = None
50
+
51
+ @torch.no_grad()
52
+ def decode_latent(self, z, z_skl, gt_joints=None, gt_parents=None, batch_size=4, gt_reps=None, gt_reps_skl=None):
53
+ self._loading_slat_dec()
54
+ reps = []
55
+ reps_skl = []
56
+ if gt_reps is not None:
57
+ skins_gt_ssskin = []
58
+ if gt_reps_skl is not None:
59
+ skins_gt_sklskin = []
60
+ if self.normalization is not None:
61
+ z = z * self.std.to(z.device) + self.mean.to(z.device)
62
+ z_skl = z_skl * self.std_skl.to(z.device) + self.mean_skl.to(z.device)
63
+ for i in range(0, z.shape[0], batch_size):
64
+ gt_j, gt_p = None if gt_joints is None else gt_joints[i:i+batch_size], None if gt_parents is None else gt_parents[i:i+batch_size]
65
+ z_, z_skl_ = z[i:i+batch_size], z_skl[i:i+batch_size]
66
+ rep, rep_skl = self.slat_dec(z_, z_skl_, gt_joints=gt_j, gt_parents=gt_p)
67
+ reps.append(rep)
68
+ reps_skl.append(rep_skl)
69
+ if gt_reps is not None:
70
+ skins_gt_ssskin.append(self.slat_dec.skinweight_forward(gt_reps[i:i+batch_size], rep_skl, gt_joints=gt_j, gt_parents=gt_p, return_skin_pred_only=True))
71
+ if gt_reps_skl is not None:
72
+ skins_gt_sklskin.append(self.slat_dec.skinweight_forward(rep, gt_reps_skl[i:i+batch_size], gt_joints=gt_j, gt_parents=gt_p, return_skin_pred_only=True))
73
+ reps = sum(reps, [])
74
+ reps_skl = sum(reps_skl, [])
75
+ self._delete_slat_dec()
76
+ to_return = (reps, reps_skl)
77
+ if gt_reps is not None:
78
+ skins_gt_ssskin = sum(skins_gt_ssskin, [])
79
+ to_return += (skins_gt_ssskin,)
80
+ if gt_reps_skl is not None:
81
+ skins_gt_sklskin = sum(skins_gt_sklskin, [])
82
+ to_return += (skins_gt_sklskin,)
83
+ return to_return
84
+
85
class AniGenSLat(AniGenSLatVisMixin, StandardDatasetBase):
    """
    structured latent dataset

    Args:
        roots (str): path to the dataset
        latent_model (str): name of the latent model
        min_aesthetic_score (float): minimum aesthetic score
        max_num_voxels (int): maximum number of voxels
        normalization (dict): normalization stats
        pretrained_slat_dec (str): name of the pretrained slat decoder
        slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
        slat_dec_ckpt (str): name of the slat decoder checkpoint
    """
    def __init__(self,
        roots: str,
        *,
        latent_model: str,
        use_joint_num_cond: bool = False,
        min_aesthetic_score: float = 5.0,
        max_num_voxels: int = 32768,
        normalization: Optional[dict] = None,
        pretrained_slat_dec: str = None,
        slat_dec_path: Optional[str] = None,
        slat_dec_ckpt: Optional[str] = None,
        local_rank: int = 0,
        **kwargs,
    ):
        self.normalization = normalization
        self.latent_model = latent_model
        self.use_joint_num_cond = use_joint_num_cond
        self.min_aesthetic_score = min_aesthetic_score
        self.max_num_voxels = max_num_voxels
        self.value_range = (0, 1)
        self.local_rank = local_rank

        super().__init__(
            roots,
            pretrained_slat_dec=pretrained_slat_dec,
            slat_dec_path=slat_dec_path,
            slat_dec_ckpt=slat_dec_ckpt,
            **kwargs,
        )

        # Per-instance voxel counts, used for load-balanced batching.
        self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]

        if self.normalization is not None:
            self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
            self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
            self.mean_skl = torch.tensor(self.normalization['mean_skl']).reshape(1, -1)
            self.std_skl = torch.tensor(self.normalization['std_skl']).reshape(1, -1)

    def filter_metadata(self, metadata):
        """Keep instances that have latents, pass the aesthetic threshold,
        and are not flagged bad-skeleton/bad-skin. Returns (metadata, stats)."""
        stats = {}
        metadata = metadata[metadata[f'latent_{self.latent_model}']]
        stats['With latent'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        # metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
        # stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)

        if 'is_bad_skeleton' in metadata.columns:
            metadata = metadata[~metadata['is_bad_skeleton']]
        if 'is_bad_skin' in metadata.columns:
            metadata = metadata[~metadata['is_bad_skin']]

        return metadata, stats

    @torch.no_grad()
    def visualize_sample(self, data: dict):
        # NOTE(review): visualization is currently disabled — this early
        # return makes everything below unreachable dead code. Either remove
        # the early return to re-enable rendering, or delete the dead body.
        return {}
        x_0 = data['x_0']
        x_0_skl = data['x_0_skl']
        reps, reps_skl = self.decode_latent(x_0.cuda(), x_0_skl.cuda(), data['joints'])

        # Build camera
        yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in yaws]
        pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]

        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitch):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(40)).cuda()
            extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
            exts.append(extrinsics)
            ints.append(intrinsics)

        renderer = get_renderer(reps[0])
        images = []
        for representation in reps:
            image = torch.zeros(3, 1024, 1024).cuda()
            tile = [2, 2]
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr)
                image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
            images.append(image)
        images = torch.stack(images)

        return images

    def _get_skeleton(self, root, instance):
        """Load joints, parents and raw skin weights for *instance*;
        optionally exposes the joint count for conditioning."""
        skeleton_path = os.path.join(root, 'skeleton', instance, 'skeleton_voxelized.npz')
        skl_data = np.load(skeleton_path, allow_pickle=True)
        joints, parents, skin = skl_data['joints'], skl_data['parents'], skl_data['skin']
        # 'parents' is an object array; the root is stored as None.
        parents[parents==None] = -1
        parents = np.array(parents, dtype=np.int32)
        ret = {
            'joints': torch.from_numpy(joints).float(),
            'parents': torch.from_numpy(parents).int(),
            'skin': torch.from_numpy(skin).float(),
        }
        if self.use_joint_num_cond:
            ret['joints_num'] = int(joints.shape[0])
        return ret

    def _get_geo(self, root, instance):
        """Load the voxelized skeleton mesh and, optionally, a cached cuBVH.

        NOTE(review): unlike the feat2skeleton dataset's _get_geo, this
        variant neither pins a CUDA device for the BVH build nor moves the
        cached BVH to CPU — confirm whether the two should be unified.
        """
        skeleton_path = os.path.join(root, 'skeleton', instance, 'skeleton_voxelized.npz')
        skl_data = np.load(skeleton_path, allow_pickle=True)
        verts, face = np.array(skl_data['vertices'], dtype=np.float32), skl_data['faces']
        mesh = {
            "vertices" : torch.from_numpy(verts),
            "faces" : torch.from_numpy(face),
        }
        geo = {"mesh": mesh}
        if self.load_cubvh:
            from cubvh import cuBVH
            torch.cuda.set_device(self.local_rank)
            cubvh_path = os.path.join(root, 'skeleton', instance, 'cubvh.pth')
            if os.path.exists(cubvh_path):
                # NOTE: weights_only=False unpickles arbitrary objects; the
                # cache file must come from a trusted source.
                cubvh = torch.load(cubvh_path, weights_only=False)
            else:
                cubvh = cuBVH(mesh["vertices"], mesh["faces"])
                torch.save(cubvh, cubvh_path)
            geo["cubvh"] = cubvh
        return geo

    def get_instance(self, root, instance):
        """Load the (optionally normalized) sparse latents plus skeleton and
        geometry annotations for one instance."""
        data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
        coords = torch.tensor(data['coords']).int()
        feats = torch.tensor(data['feats']).float()
        coords_skl = torch.tensor(data['coords_skl']).int()
        feats_skl = torch.tensor(data['feats_skl']).float()
        if self.normalization is not None:
            feats = (feats - self.mean) / self.std
            feats_skl = (feats_skl - self.mean_skl) / self.std_skl
        return {
            'coords': coords,
            'feats': feats,
            'coords_skl': coords_skl,
            'feats_skl': feats_skl,
            'instance': instance,
            **self._get_skeleton(root, instance),
            **self._get_geo(root, instance),
        }

    @staticmethod
    def collate_fn(batch, split_size=None):
        """Collate samples into SparseTensors; when *split_size* is given,
        samples are grouped into load-balanced sub-batches (one pack each)."""
        if split_size is None:
            group_idx = [list(range(len(batch)))]
        else:
            # Balance groups by per-sample voxel count.
            group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
        packs = []
        for group in group_idx:
            sub_batch = [batch[i] for i in group]
            pack = {}
            coords = []
            feats = []
            coords_skl = []
            feats_skl = []
            layout = []
            layout_skl = []
            start = 0
            start_skl = 0
            for i, b in enumerate(sub_batch):
                # Prepend the in-group batch index as the first coordinate
                # column; record each sample's slice into the flat buffers.
                coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
                feats.append(b['feats'])
                coords_skl.append(torch.cat([torch.full((b['coords_skl'].shape[0], 1), i, dtype=torch.int32), b['coords_skl']], dim=-1))
                feats_skl.append(b['feats_skl'])
                layout.append(slice(start, start + b['coords'].shape[0]))
                layout_skl.append(slice(start_skl, start_skl + b['coords_skl'].shape[0]))
                start += b['coords'].shape[0]
                start_skl += b['coords_skl'].shape[0]
            coords = torch.cat(coords)
            feats = torch.cat(feats)
            pack['x_0'] = SparseTensor(
                coords=coords,
                feats=feats,
            )
            pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
            pack['x_0'].register_spatial_cache('layout', layout)

            coords_skl = torch.cat(coords_skl)
            feats_skl = torch.cat(feats_skl)
            pack['x_0_skl'] = SparseTensor(
                coords=coords_skl,
                feats=feats_skl,
            )
            pack['x_0_skl']._shape = torch.Size([len(group), *sub_batch[0]['feats_skl'].shape[1:]])
            pack['x_0_skl'].register_spatial_cache('layout', layout_skl)

            # Variable-length skeleton data stays as per-sample lists.
            pack['joints'] = [b['joints'] for b in sub_batch]
            pack['parents'] = [b['parents'] for b in sub_batch]
            pack['skin'] = [b['skin'] for b in sub_batch]
            if 'joints_num' in sub_batch[0]:
                pack['joints_num'] = torch.tensor([b['joints_num'] for b in sub_batch], dtype=torch.long)

            # collate other data
            keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats', 'coords_skl', 'feats_skl', 'joints', 'parents', 'skin', 'joints_num']]
            for k in keys:
                if isinstance(sub_batch[0][k], torch.Tensor):
                    pack[k] = torch.stack([b[k] for b in sub_batch])
                elif isinstance(sub_batch[0][k], list):
                    pack[k] = sum([b[k] for b in sub_batch], [])
                else:
                    pack[k] = [b[k] for b in sub_batch]

            packs.append(pack)

        if split_size is None:
            return packs[0]
        return packs
314
+
315
+
316
class TextConditionedSLat(TextConditionedMixin, AniGenSLat):
    """Structured latent dataset with text conditioning mixed in."""
321
+
322
+
323
class AniGenImageConditionedSLat(ImageConditionedMixin, AniGenSLat):
    """Structured latent dataset with image conditioning mixed in."""
anigen/datasets/components.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from abc import abstractmethod
3
+ import os
4
+ import json
5
+ import torch
6
+ import numpy as np
7
+ import pandas as pd
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+
11
+
12
class StandardDatasetBase(Dataset):
    """
    Base class for standard datasets.

    Each root directory must contain a ``metadata.csv`` file with (at least)
    a ``sha256`` column identifying the instances.  Subclasses decide which
    metadata rows are usable (``filter_metadata``) and how to load a single
    instance (``get_instance``).

    Args:
        roots (str): comma-separated paths to the dataset roots.
        instances (Optional[List[str]]): if given, restrict the dataset to
            these sha256 identifiers.
    """

    def __init__(self,
        roots: str,
        instances: Optional[List[str]] = None,
        **kwargs,
    ):
        super().__init__()
        self.roots = roots.split(',')
        self.instances = []
        self.metadata = pd.DataFrame()

        # Per-root bookkeeping of how many rows survive each filtering stage.
        self._stats = {}
        for root in self.roots:
            key = os.path.basename(root)
            self._stats[key] = {}
            metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
            self._stats[key]['Total'] = len(metadata)
            metadata, stats = self.filter_metadata(metadata)
            self._stats[key].update(stats)
            self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
            metadata.set_index('sha256', inplace=True)
            self.metadata = pd.concat([self.metadata, metadata])

        if instances is not None:
            # NOTE(review): test_mode is only defined on this branch — code
            # reading it elsewhere must handle its absence; confirm intent.
            self.test_mode = False
            # Use a set for O(1) membership tests instead of scanning the
            # (possibly long) instances list for every candidate.
            instance_set = set(instances)
            self.instances = [inst for inst in self.instances if inst[1] in instance_set]

    @abstractmethod
    def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
        """Return (filtered metadata, {stage name: surviving row count})."""
        pass

    @abstractmethod
    def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
        """Load one instance (identified by sha256) from the given root."""
        pass

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, index) -> Dict[str, Any]:
        # Best-effort loading: on any failure, log and retry with a random
        # index so a corrupt sample does not abort training.
        try:
            root, instance = self.instances[index]
            return self.get_instance(root, instance)
        except Exception as e:
            print(e)
            return self.__getitem__(np.random.randint(0, len(self)))

    def __str__(self):
        lines = []
        lines.append(self.__class__.__name__)
        lines.append(f'  - Total instances: {len(self)}')
        lines.append(f'  - Sources:')
        for key, stats in self._stats.items():
            lines.append(f'    - {key}:')
            for k, v in stats.items():
                lines.append(f'      - {k}: {v}')
        return '\n'.join(lines)
75
+
76
+
77
class TextConditionedMixin:
    """
    Dataset mixin that attaches a text caption as the condition.

    Expects the cooperating base class to provide ``self.instances``
    (``(root, sha256)`` pairs) and ``self.metadata`` (indexed by sha256 with
    a JSON-encoded ``captions`` column).
    """

    def __init__(self, roots, **kwargs):
        super().__init__(roots, **kwargs)
        # Decode every instance's JSON caption list once, keyed by sha256.
        self.captions = {
            inst[1]: json.loads(self.metadata.loc[inst[1]]['captions'])
            for inst in self.instances
        }

    def filter_metadata(self, metadata):
        """Keep only rows that actually carry captions."""
        filtered, stats = super().filter_metadata(metadata)
        filtered = filtered[filtered['captions'].notna()]
        stats['With captions'] = len(filtered)
        return filtered, stats

    def get_instance(self, root, instance):
        """Load the instance and attach one randomly chosen caption."""
        pack = super().get_instance(root, instance)
        pack['cond'] = np.random.choice(self.captions[instance])
        return pack
96
+
97
+
98
class ImageConditionedMixin:
    """
    Dataset mixin that attaches a conditioning image as ``pack['cond']``.

    A random pre-rendered view of the instance is loaded from
    ``<root>/renders_cond/<instance>``, cropped around its alpha bounding box
    (enlarged by 20%), resized to ``image_size``, and alpha-premultiplied.

    Args:
        roots: forwarded to the cooperating base dataset class.
        image_size: side length of the square conditioning image (default 518).
    """

    def __init__(self, roots, *, image_size=518, **kwargs):
        # Set before super().__init__ so filter_metadata (called there) can
        # already rely on a fully initialized mixin.
        self.image_size = image_size
        super().__init__(roots, **kwargs)

    def filter_metadata(self, metadata):
        """Keep only rows whose conditioning renders exist."""
        metadata, stats = super().filter_metadata(metadata)
        metadata = metadata[metadata['cond_rendered']]  # fixed: was a pointless f-string column key
        stats['Cond rendered'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        pack = super().get_instance(root, instance)

        # Pick a random rendered view of this instance.
        image_root = os.path.join(root, 'renders_cond', instance)
        with open(os.path.join(image_root, 'transforms.json')) as f:
            metadata = json.load(f)
        n_views = len(metadata['frames'])
        view = np.random.randint(n_views)
        metadata = metadata['frames'][view]

        image_path = os.path.join(image_root, metadata['file_path'])
        image = Image.open(image_path)

        # Crop to the alpha-channel bounding box, enlarged by 20%.
        # NOTE(review): assumes at least one non-transparent pixel; a fully
        # transparent render would raise here (caught by the dataset's
        # best-effort __getitem__).
        alpha = np.array(image.getchannel(3))
        bbox = alpha.nonzero()  # fixed: alpha is already an ndarray, no re-wrap needed
        bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
        center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
        hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
        aug_size_ratio = 1.2
        aug_hsize = hsize * aug_size_ratio
        aug_center_offset = [0, 0]
        aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
        aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
        image = image.crop(aug_bbox)

        # Resize, premultiply RGB by alpha, and convert to a CHW float tensor.
        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
        alpha = image.getchannel(3)
        image = image.convert('RGB')
        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
        alpha = torch.tensor(np.array(alpha)).float() / 255.0
        image = image * alpha.unsqueeze(0)
        pack['cond'] = image

        return pack
143
+
anigen/models/__init__.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
# Lazy-import registry: maps each public class name to the submodule
# (relative to this package) that defines it.  Heavy model modules are only
# imported the first time one of their classes is actually requested.
__attributes = {
    'AniGenSparseStructureEncoder': 'anigen_sparse_structure_vae',
    'AniGenSparseStructureDecoder': 'anigen_sparse_structure_vae',
    'AniGenSparseStructureFlowModel': 'anigen_sparse_structure_flow',
    'AniGenSparseStructureFlowModelInpaint': 'anigen_sparse_structure_flow_inpaint',
    'AniGenElasticSLatEncoder': 'structured_latent_vae',
    'AniGenElasticSLatMeshDecoder': 'structured_latent_vae',
    'AniGenElasticSLatGaussianDecoder': 'structured_latent_vae',
    'AniGenSLatFlowModel': 'anigen_structured_latent_flow',
    'AniGenElasticSLatFlowModel': 'anigen_structured_latent_flow',
    'AniGenElasticSLatFlowModelOld': 'anigen_structured_latent_flow_old',
    'SkinAutoEncoder': 'structured_latent_vae',
}

# Submodules to expose lazily as attributes of this package (currently none).
__submodules = []

__all__ = list(__attributes.keys()) + __submodules

def __getattr__(name):
    """
    Module-level lazy attribute hook (PEP 562).

    Resolves a registered class or submodule on first access, caches it in
    ``globals()`` so subsequent lookups bypass this hook, and raises
    AttributeError for unknown names.
    """
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
33
+
34
+
35
def from_pretrained(path: str, **kwargs):
    """
    Load a model from a pretrained checkpoint.

    Args:
        path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
        NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
        **kwargs: Additional arguments for the model constructor.
    """
    import os
    import json
    from safetensors.torch import load_file

    local_config = f"{path}.json"
    local_weights = f"{path}.safetensors"

    if os.path.exists(local_config) and os.path.exists(local_weights):
        # Checkpoint is available on the local filesystem.
        config_file, model_file = local_config, local_weights
    else:
        # Fall back to the Hugging Face Hub: the first two path segments name
        # the repo, the remainder names the checkpoint inside it.
        print(f"{path}.json and {path}.safetensors not found, trying to download from Hugging Face Hub.")
        from huggingface_hub import hf_hub_download
        parts = path.split('/')
        repo_id = f'{parts[0]}/{parts[1]}'
        model_name = '/'.join(parts[2:])
        config_file = hf_hub_download(repo_id, f"{model_name}.json")
        model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")

    with open(config_file, 'r') as f:
        config = json.load(f)

    # Instantiate the registered class lazily and restore its weights.
    model = __getattr__(config['name'])(**config['args'], **kwargs)
    model.load_state_dict(load_file(model_file))

    return model
67
+
anigen/models/anigen_sparse_structure_flow.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from ..modules.utils import convert_module_to_f16, convert_module_to_f32
7
+ from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
8
+ from ..modules.spatial import patchify, unpatchify
9
+
10
+
11
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal frequency embedding is computed first, then projected
    by a two-layer SiLU MLP into ``hidden_size`` dimensions.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t: a 1-D Tensor of N indices, one per batch element (may be fractional).
            dim: the dimension of the output.
            max_period: controls the minimum frequency of the embeddings.

        Returns:
            an (N, dim) Tensor of positional embeddings (zero-padded in the
            last column when ``dim`` is odd).
        """
        # Geometric frequency ladder, as in the GLIDE reference implementation.
        half = dim // 2
        exponents = -np.log(max_period) * torch.arange(half, dtype=torch.float32) / half
        freqs = torch.exp(exponents).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
53
+
54
+
55
class AniGenSparseStructureFlowModel(nn.Module):
    """
    Dual-branch flow transformer over dense sparse-structure latents.

    Two parallel stacks of modulated cross-attention transformer blocks are
    run in lock-step: the main ("ss") branch processes the shape latent ``x``
    and the "skl" branch processes the skeleton latent ``x_skl``.  The
    branches exchange information either via per-block zero-initialized
    linear adapters (when both latents live on the same dense token grid) or
    via periodic cross-attention adapter blocks (when either latent is a
    global token sequence, since token counts may then differ).

    Key argument groups (see the parameter list):
      - resolution / patch_size / in_channels* / out_channels*: dense latent
        geometry; inputs are patchified into tokens unless *_is_global.
      - use_lora_ss / lora_lr_rate_ss: enable LoRA adapters in the main
        branch (deeper blocks get self-attn LoRA; cross-attn LoRA rank grows
        with depth).
      - use_pretrain_branch / freeze_pretrain_branch / modules_to_freeze:
        freeze the listed pretrained submodules, keeping LoRA params trainable.
      - predict_x0 / predict_x0_skl / t_eps / t_scale: reinterpret the head
        output as an x0 prediction and convert it to a velocity.

    NOTE(review): ``modules_to_freeze`` uses a mutable list as its default
    value; it is only read, never mutated, so this is harmless here, but a
    tuple or ``None`` sentinel would be safer.
    """
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        in_channels_skl: int,
        model_channels: int,
        model_channels_skl: int,
        cond_channels: int,
        out_channels: int,
        out_channels_skl: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        patch_size: int = 2,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        share_mod: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        use_lora_ss: bool = False,
        lora_lr_rate_ss: float = 0.1,
        modules_to_freeze: Optional[List[str]] = ["blocks", "input_layer", "out_layer", "pos_emb", "t_embedder"],
        adapter_ss_to_skl: bool = True,
        adapter_skl_to_ss: bool = True,
        predict_x0: bool = False,
        predict_x0_skl: bool = False,
        t_eps: float = 5e-2,
        t_scale: float = 1e3,
        z_is_global: bool = False,
        z_skl_is_global: bool = False,
        global_token_num: int = 1024,
        global_token_num_skl: int = 1024,
        cross_adapter_every: int = 4,
        skl_cross_from_ss: bool = False,
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.in_channels_skl = in_channels_skl
        self.model_channels = model_channels
        self.model_channels_skl = model_channels_skl
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.out_channels_skl = out_channels_skl
        self.num_blocks = num_blocks
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.patch_size = patch_size
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.use_pretrain_branch = use_pretrain_branch
        # LoRA finetuning implies the pretrained branch is frozen.
        self.freeze_pretrain_branch = freeze_pretrain_branch or use_lora_ss
        self.use_lora_ss = use_lora_ss
        self.modules_to_freeze = modules_to_freeze
        self.adapter_ss_to_skl = adapter_ss_to_skl
        self.adapter_skl_to_ss = adapter_skl_to_ss
        self.predict_x0 = predict_x0
        self.predict_x0_skl = predict_x0_skl
        self.t_eps = t_eps
        self.t_scale = t_scale
        self.z_is_global = z_is_global
        self.z_skl_is_global = z_skl_is_global
        self.global_token_num = global_token_num
        self.global_token_num_skl = global_token_num_skl
        self.cross_adapter_every = int(cross_adapter_every)
        self.skl_cross_from_ss = skl_cross_from_ss

        # Separate timestep embedders (and, optionally, shared adaLN
        # modulation heads) for each branch.
        self.t_embedder = TimestepEmbedder(model_channels)
        self.t_embedder_skl = TimestepEmbedder(model_channels_skl)
        if share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )
            self.adaLN_modulation_skl = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels_skl, 6 * model_channels_skl, bias=True)
            )

        # Absolute position embeddings: 3D grid coordinates for dense token
        # grids, or a 1D index embedding when the latent is a global sequence.
        if pe_mode == "ape":
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            if self.z_is_global:
                pos_embedder = AbsolutePositionEmbedder(model_channels, 1)
                pos_emb = pos_embedder(torch.arange(self.global_token_num, device=self.device)[:, None])
            else:
                pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
                pos_emb = pos_embedder(coords)
            self.register_buffer("pos_emb", pos_emb)
            if self.z_skl_is_global:
                pos_embedder_skl = AbsolutePositionEmbedder(model_channels_skl, 1)
                pos_emb_skl = pos_embedder_skl(torch.arange(self.global_token_num_skl, device=self.device)[:, None])
            else:
                pos_embedder_skl = AbsolutePositionEmbedder(model_channels_skl, 3)
                pos_emb_skl = pos_embedder_skl(coords)
            self.register_buffer("pos_emb_skl", pos_emb_skl)

        # Patch-to-token projections for each branch.
        self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
        self.input_layer_skl = nn.Linear(in_channels_skl * patch_size**3, model_channels_skl)

        # Depth thresholds for LoRA configuration: cross-attn LoRA rank grows
        # every `shallow` blocks; self-attn LoRA is enabled from `middle` on.
        shallow = max(1, num_blocks // 3)
        middle = max(1, num_blocks // 3 * 2)
        self.blocks = nn.ModuleList([
            ModulatedTransformerCrossBlock(
                model_channels,
                cond_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                share_mod=share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
                use_lora_self=self.use_lora_ss and idx >= middle,
                lora_rank_self=8,
                use_lora_cross=self.use_lora_ss,
                lora_rank_cross=8+(idx // shallow)*8,
                lora_lr_rate=lora_lr_rate_ss,
            )
            for idx in range(num_blocks)
        ])
        # Skeleton branch: cross-attends either to the external condition or
        # (when skl_cross_from_ss) to the main branch's token stream.
        self.blocks_skl = nn.ModuleList([
            ModulatedTransformerCrossBlock(
                model_channels_skl,
                cond_channels if not self.skl_cross_from_ss else model_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode='full',
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                share_mod=share_mod,
                qk_rms_norm=self.qk_rms_norm,
                qk_rms_norm_cross=self.qk_rms_norm_cross,
                use_context_norm=self.skl_cross_from_ss,
            )
            for _ in range(num_blocks)
        ])

        # When using global tokens, ss and skl token counts may differ, so we use cross-attention
        # for information exchange at a configurable frequency.
        self.use_cross_adapter = (self.z_is_global or self.z_skl_is_global) and (
            self.adapter_ss_to_skl or self.adapter_skl_to_ss
        )

        # Dense-grid case: one linear adapter per block in each direction.
        if self.adapter_ss_to_skl and not self.use_cross_adapter:
            self.adapter_ss_to_skl_layers = nn.ModuleList([
                nn.Linear(model_channels, model_channels_skl) for _ in range(num_blocks)
            ])
        if self.adapter_skl_to_ss and not self.use_cross_adapter:
            self.adapter_skl_to_ss_layers = nn.ModuleList([
                nn.Linear(model_channels_skl, model_channels) for _ in range(num_blocks)
            ])

        # Schedule of block indices after which cross-adapters fire; at least
        # one exchange is guaranteed when cross-adapters are enabled.
        self.cross_adapter_every = max(1, self.cross_adapter_every)
        self.cross_block_indices: List[int] = [
            idx for idx in range(num_blocks) if (idx + 1) % self.cross_adapter_every == 0
        ]
        if self.use_cross_adapter and len(self.cross_block_indices) == 0 and num_blocks > 0:
            self.cross_block_indices = [num_blocks - 1]
        if self.use_cross_adapter and len(self.cross_block_indices) > 0:
            if self.adapter_ss_to_skl:
                self.cross_blocks_ss_to_skl = nn.ModuleList([
                    ModulatedTransformerCrossBlock(
                        model_channels_skl,
                        model_channels,
                        num_heads=self.num_heads,
                        mlp_ratio=self.mlp_ratio,
                        attn_mode='full',
                        use_checkpoint=self.use_checkpoint,
                        use_rope=(pe_mode == "rope"),
                        share_mod=share_mod,
                        qk_rms_norm=self.qk_rms_norm,
                        qk_rms_norm_cross=self.qk_rms_norm_cross,
                    )
                    for _ in self.cross_block_indices
                ])
                self.cross_blocks_ss_to_skl_out = nn.ModuleList([
                    nn.Linear(model_channels_skl, model_channels_skl, bias=True)
                    for _ in self.cross_block_indices
                ])
            if self.adapter_skl_to_ss:
                self.cross_blocks_skl_to_ss = nn.ModuleList([
                    ModulatedTransformerCrossBlock(
                        model_channels,
                        model_channels_skl,
                        num_heads=self.num_heads,
                        mlp_ratio=self.mlp_ratio,
                        attn_mode='full',
                        use_checkpoint=self.use_checkpoint,
                        use_rope=(pe_mode == "rope"),
                        share_mod=share_mod,
                        qk_rms_norm=self.qk_rms_norm,
                        qk_rms_norm_cross=self.qk_rms_norm_cross,
                    )
                    for _ in self.cross_block_indices
                ])
                self.cross_blocks_skl_to_ss_out = nn.ModuleList([
                    nn.Linear(model_channels, model_channels, bias=True)
                    for _ in self.cross_block_indices
                ])

        # Token-to-patch output projections for each branch.
        self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
        self.out_layer_skl = nn.Linear(model_channels_skl, out_channels_skl * patch_size**3)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

        # Freeze the listed pretrained submodules; LoRA parameters (names
        # containing 'lora') stay trainable.
        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for name, param in m.named_parameters():
                                if 'lora' not in name:
                                    param.requires_grad = False
                    elif isinstance(mod, nn.Module):
                        for name, param in mod.named_parameters():
                            if 'lora' not in name:
                                param.requires_grad = False
                    elif isinstance(mod, torch.Tensor):
                        if mod.requires_grad:
                            mod.requires_grad = False

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.

        Embedders, input/output projections, and position embeddings are kept
        in float32; only the transformer blocks and adapters are converted.
        """
        self.blocks.apply(convert_module_to_f16)
        self.blocks_skl.apply(convert_module_to_f16)
        if hasattr(self, "adapter_ss_to_skl_layers"):
            self.adapter_ss_to_skl_layers.apply(convert_module_to_f16)
        if hasattr(self, "adapter_skl_to_ss_layers"):
            self.adapter_skl_to_ss_layers.apply(convert_module_to_f16)
        if getattr(self, "use_cross_adapter", False):
            if hasattr(self, "cross_blocks_ss_to_skl"):
                self.cross_blocks_ss_to_skl.apply(convert_module_to_f16)
                self.cross_blocks_ss_to_skl_out.apply(convert_module_to_f16)
            if hasattr(self, "cross_blocks_skl_to_ss"):
                self.cross_blocks_skl_to_ss.apply(convert_module_to_f16)
                self.cross_blocks_skl_to_ss_out.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.

        Mirror of convert_to_fp16; converts the same set of submodules back.
        """
        self.blocks.apply(convert_module_to_f32)
        self.blocks_skl.apply(convert_module_to_f32)
        if hasattr(self, "adapter_ss_to_skl_layers"):
            self.adapter_ss_to_skl_layers.apply(convert_module_to_f32)
        if hasattr(self, "adapter_skl_to_ss_layers"):
            self.adapter_skl_to_ss_layers.apply(convert_module_to_f32)
        if getattr(self, "use_cross_adapter", False):
            if hasattr(self, "cross_blocks_ss_to_skl"):
                self.cross_blocks_ss_to_skl.apply(convert_module_to_f32)
                self.cross_blocks_ss_to_skl_out.apply(convert_module_to_f32)
            if hasattr(self, "cross_blocks_skl_to_ss"):
                self.cross_blocks_skl_to_ss.apply(convert_module_to_f32)
                self.cross_blocks_skl_to_ss_out.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """
        Initialize all weights; zero-init modulation, output, and adapter
        layers so the untrained model starts as (near-)identity and a
        pretrained checkpoint can be finetuned safely.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_embedder_skl.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder_skl.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        if self.share_mod:
            nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(self.adaLN_modulation_skl[-1].weight, 0)
            nn.init.constant_(self.adaLN_modulation_skl[-1].bias, 0)
        else:
            for block in self.blocks:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
            for block in self.blocks_skl:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)
        nn.init.constant_(self.out_layer_skl.weight, 0)
        nn.init.constant_(self.out_layer_skl.bias, 0)

        # Zero-out adapter layers if exist
        if hasattr(self, "adapter_ss_to_skl_layers"):
            for layer in self.adapter_ss_to_skl_layers:
                nn.init.constant_(layer.weight, 0)
                nn.init.constant_(layer.bias, 0)
        if hasattr(self, "adapter_skl_to_ss_layers"):
            for layer in self.adapter_skl_to_ss_layers:
                nn.init.constant_(layer.weight, 0)
                nn.init.constant_(layer.bias, 0)

        # Zero-out cross adapter output projections (so we can safely finetune from pretrained ckpt)
        if getattr(self, "use_cross_adapter", False):
            if hasattr(self, "cross_blocks_ss_to_skl_out"):
                for layer in self.cross_blocks_ss_to_skl_out:
                    nn.init.constant_(layer.weight, 0)
                    nn.init.constant_(layer.bias, 0)
            if hasattr(self, "cross_blocks_skl_to_ss_out"):
                for layer in self.cross_blocks_skl_to_ss_out:
                    nn.init.constant_(layer.weight, 0)
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor, x_skl: torch.Tensor, t: torch.Tensor, cond: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run both branches for one denoising step.

        Args:
            x: shape latent; (B, in_channels, R, R, R) dense volume unless
               z_is_global, in which case it is already a token sequence.
            x_skl: skeleton latent, analogous layout.
            t: (B,) timesteps.
            cond: conditioning token sequence (e.g. image/text features).

        Returns:
            (h, h_skl): per-branch predictions with the same layouts as
            x and x_skl.  When predict_x0(_skl) is set the head output is
            treated as an x0 estimate and rescaled into a velocity via
            (x_t - x0) / max(t / t_scale, t_eps).
        """
        if not self.z_is_global:
            assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
                    f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
        if not self.z_skl_is_global:
            assert [*x_skl.shape] == [x_skl.shape[0], self.in_channels_skl, *[self.resolution] * 3], \
                    f"Input shape mismatch, got {x_skl.shape}, expected {[x_skl.shape[0], self.in_channels_skl, *[self.resolution] * 3]}"

        # Keep the noisy inputs around for the x0 -> velocity conversion.
        if self.predict_x0:
            xt = x.clone()
        if self.predict_x0_skl:
            xt_skl = x_skl.clone()

        # Patchify dense volumes into (B, tokens, features); global latents
        # are already token sequences.
        if not self.z_is_global:
            h = patchify(x, self.patch_size)
            h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
        else:
            h = x
        if not self.z_skl_is_global:
            h_skl = patchify(x_skl, self.patch_size)
            h_skl = h_skl.view(*h_skl.shape[:2], -1).permute(0, 2, 1).contiguous()
        else:
            h_skl = x_skl

        h = self.input_layer(h)
        h = h + self.pos_emb[None]
        h_skl = self.input_layer_skl(h_skl)
        h_skl = h_skl + self.pos_emb_skl[None]

        # Both branches share the same timestep t but have separate embedders.
        t_emb = self.t_embedder(t)
        t_emb_skl = self.t_embedder_skl(t)
        if self.share_mod:
            t_emb = self.adaLN_modulation(t_emb)
            t_emb_skl = self.adaLN_modulation_skl(t_emb_skl)
        t_emb = t_emb.type(self.dtype)
        t_emb_skl = t_emb_skl.type(self.dtype)

        h = h.type(self.dtype)
        h_skl = h_skl.type(self.dtype)
        cond = cond.type(self.dtype)

        # Map block index -> cross-adapter slot for scheduled exchanges.
        cross_pos_to_idx = None
        if self.use_cross_adapter and len(self.cross_block_indices) > 0:
            cross_pos_to_idx = {bidx: cidx for cidx, bidx in enumerate(self.cross_block_indices)}

        for idx, block, block_skl in zip(range(len(self.blocks)), self.blocks, self.blocks_skl):
            f = block(h, t_emb, cond)
            f_skl = block_skl(h_skl, t_emb_skl, h if self.skl_cross_from_ss else cond)

            if self.use_cross_adapter and cross_pos_to_idx is not None and idx in cross_pos_to_idx:
                # Scheduled cross-attention exchange; the zero-initialized
                # *_out projections make this a residual no-op at start.
                cidx = cross_pos_to_idx[idx]
                if self.adapter_ss_to_skl:
                    out_skl = self.cross_blocks_ss_to_skl[cidx](f_skl, t_emb_skl, f)
                    h_skl = f_skl + self.cross_blocks_ss_to_skl_out[cidx](out_skl - f_skl)
                else:
                    h_skl = f_skl

                if self.adapter_skl_to_ss:
                    out = self.cross_blocks_skl_to_ss[cidx](f, t_emb, f_skl)
                    h = f + self.cross_blocks_skl_to_ss_out[cidx](out - f)
                else:
                    h = f
            else:
                # Non-global (or no cross block at this idx): keep previous behavior.
                if self.adapter_ss_to_skl and (not self.use_cross_adapter):
                    h_skl = f_skl + self.adapter_ss_to_skl_layers[idx](f)
                else:
                    h_skl = f_skl

                if self.adapter_skl_to_ss and (not self.use_cross_adapter):
                    h = f + self.adapter_skl_to_ss_layers[idx](f_skl)
                else:
                    h = f
        # Final norm + projection back to patch features, in input dtype.
        h = h.type(x.dtype)
        h = F.layer_norm(h, h.shape[-1:])
        h = self.out_layer(h)
        h_skl = h_skl.type(x_skl.dtype)
        h_skl = F.layer_norm(h_skl, h_skl.shape[-1:])
        h_skl = self.out_layer_skl(h_skl)

        # Un-patchify dense branches back to (B, C, R, R, R).
        if not self.z_is_global:
            h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
            h = unpatchify(h, self.patch_size).contiguous()

        if not self.z_skl_is_global:
            h_skl = h_skl.permute(0, 2, 1).view(h_skl.shape[0], h_skl.shape[2], *[self.resolution // self.patch_size] * 3)
            h_skl = unpatchify(h_skl, self.patch_size).contiguous()

        # Convert an x0 prediction into a velocity: v = (x_t - x0) / t,
        # with t clamped below by t_eps after normalizing by t_scale.
        if self.predict_x0:
            t_normalized = t / self.t_scale
            factor = (1 / t_normalized.clamp_min(self.t_eps)).reshape([t.shape[0], *([1] * (x.dim() - 1))])
            h = (xt - h) * factor
        if self.predict_x0_skl:
            t_normalized = t / self.t_scale
            factor = (1 / t_normalized.clamp_min(self.t_eps)).reshape([t.shape[0], *([1] * (x_skl.dim() - 1))])
            h_skl = (xt_skl - h_skl) * factor

        return h, h_skl
anigen/models/anigen_sparse_structure_vae.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ..modules.norm import GroupNorm32, ChannelLayerNorm32
6
+ from ..modules.spatial import pixel_shuffle_3d
7
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
8
+ from ..modules.transformer import FeedForwardNet, TransformerBlock, TransformerCrossBlock, AbsolutePositionEmbedder
9
+
10
+
11
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
    """
    Build a normalization layer of the requested kind.

    Args:
        norm_type: "group" for a 32-group GroupNorm32, "layer" for
            ChannelLayerNorm32.
        *args, **kwargs: forwarded to the layer constructor.

    Raises:
        ValueError: if ``norm_type`` is not recognized.
    """
    if norm_type == "group":
        return GroupNorm32(32, *args, **kwargs)
    if norm_type == "layer":
        return ChannelLayerNorm32(*args, **kwargs)
    raise ValueError(f"Invalid norm type {norm_type}")
21
+
22
+
23
class ResBlock3d(nn.Module):
    """
    3D residual block: two norm -> SiLU -> conv stages plus a skip path.

    The second convolution is zero-initialized (via ``zero_module``), so a
    freshly constructed block initially passes through only the skip path.
    """

    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.norm1 = norm_layer(norm_type, channels)
        self.norm2 = norm_layer(norm_type, self.out_channels)
        self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
        self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
        # A 1x1 conv matches channel counts on the skip path when needed.
        if channels != self.out_channels:
            self.skip_connection = nn.Conv3d(channels, self.out_channels, 1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.skip_connection(x)
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return h + residual
49
+
50
+
51
class DownsampleBlock3d(nn.Module):
    """
    Halve the spatial resolution of a 3D feature map.

    Mode "conv" uses a learned stride-2 convolution (and may change the
    channel count); mode "avgpool" uses parameter-free 2x2x2 average pooling
    and therefore requires the channel count to stay fixed.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
        elif mode == "avgpool":
            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Presence of the conv submodule distinguishes the two modes.
        if hasattr(self, "conv"):
            return self.conv(x)
        return F.avg_pool3d(x, 2)
74
+
75
+
76
class UpsampleBlock3d(nn.Module):
    """Double the spatial resolution of a dense 3D feature map.

    ``mode="conv"`` expands channels 8x with a 3x3x3 convolution and folds
    them into space via ``pixel_shuffle_3d``; ``mode="nearest"`` uses
    nearest-neighbor interpolation and requires matching channel counts.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "nearest"] = "conv",
    ):
        assert mode in ["conv", "nearest"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "nearest":
            assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
        else:
            # 8x channels feed the 2x shuffle across the three spatial dims.
            self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample (B, C, D, H, W) -> (B, C', 2D, 2H, 2W)."""
        if not hasattr(self, "conv"):
            return F.interpolate(x, scale_factor=2, mode="nearest")
        return pixel_shuffle_3d(self.conv(x), 2)
100
+
101
+
102
class AniGenSparseStructureEncoder(nn.Module):
    """
    Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).

    Runs two parallel convolutional towers over dense 64^3 voxel grids: one
    for the shape occupancy (``x``) and one for the skeleton (``x_skl``).
    Each tower can optionally compress its grid features into a fixed set of
    global latent tokens via cross-attention (``encode_global`` /
    ``encode_global_skl``); otherwise the latent stays a dense low-res grid.

    Args:
        in_channels (int): Channels of the shape input.
        in_channels_skl (int): Channels of the skeleton input.
        latent_channels (int): Channels of the shape latent representation.
        latent_channels_skl (int): Channels of the skeleton latent.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the encoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to use FP16 for the torso.
        encode_global (bool): Compress shape features to global tokens.
        global_token_num (int): Number of global shape tokens.
        encode_global_skl (bool): Compress skeleton features to global tokens.
        global_token_num_skl (int): Number of global skeleton tokens.
        use_pretrain_branch / freeze_pretrain_branch: When both set, the
            modules named in ``modules_to_freeze`` get requires_grad=False
            (presumably the pretrained shape branch — weights are expected to
            be loaded externally, since ``initialize_weights`` reinitializes
            Linears at construction; TODO confirm load order).
        latent_denoising / latent_denoising_skl (bool): When set, the head
            emits a deterministic latent (no mean/logvar split) and the mean
            is detached before being returned.
        normalize_z / normalize_z_skl (bool): L2-normalize the latent
            (only effective together with the matching latent_denoising flag).
        normalize_scale (float): Scale applied after L2 normalization.
    """
    def __init__(
        self,
        in_channels: int,
        in_channels_skl: int,
        latent_channels: int,
        latent_channels_skl: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
        encode_global: bool = False,
        global_token_num: int = 1024,
        encode_global_skl: bool = True,
        global_token_num_skl: int = 1024,
        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        # NOTE(review): mutable default list — shared across instances; it is
        # only read here, but a None default with a fallback would be safer.
        modules_to_freeze: Optional[List[str]] = ["input_layer", "blocks", "middle_block", "out_layer"],
        latent_denoising: bool = False,
        latent_denoising_skl: bool = True,
        normalize_z: bool = False,
        normalize_z_skl: bool = True,
        normalize_scale: float = 1.0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.in_channels_skl = in_channels_skl
        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.encode_global = encode_global
        self.global_token_num = global_token_num
        self.encode_global_skl = encode_global_skl
        self.global_token_num_skl = global_token_num_skl
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.latent_denoising = latent_denoising
        self.latent_denoising_skl = latent_denoising_skl
        # Normalization is only meaningful for deterministic (denoising) latents.
        self.normalize_latent = normalize_z and latent_denoising
        self.normalize_latent_skl = normalize_z_skl and latent_denoising_skl
        self.normalize_scale = normalize_scale

        self.input_layer = nn.Conv3d(self.in_channels, channels[0], 3, padding=1)
        self.input_layer_skl = nn.Conv3d(self.in_channels_skl, channels[0], 3, padding=1)

        # Mirrored residual/downsample towers for the shape and skeleton branches.
        self.blocks = nn.ModuleList([])
        self.blocks_skl = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            self.blocks_skl.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    DownsampleBlock3d(ch, channels[i+1])
                )
                self.blocks_skl.append(
                    DownsampleBlock3d(ch, channels[i+1])
                )

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])
        self.middle_block_skl = nn.Sequential(*[
            # NOTE(review): both branches of this conditional are identical
            # (channels[-1]); likely a leftover from an earlier variant.
            ResBlock3d(channels[-1] if _ == 0 else channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])

        if self.encode_global:
            # Initial Tokens and PE: learned query tokens plus a 1-D index
            # positional embedding, cross-attended against the grid features.
            self.init_tokens_ss = nn.Parameter(torch.zeros(1, global_token_num, channels[-1]))
            pos_embedder = AbsolutePositionEmbedder(channels[-1], 1)
            coords = torch.arange(global_token_num, device=self.device).reshape(-1, 1)
            tokens_pos_emb = pos_embedder(coords)
            self.register_buffer('tokens_pos_emb_ss', tokens_pos_emb)
            # Grids PE: 3-D positional embedding over the downsampled grid.
            upsample_factor = 2 ** (len(channels) - 1)
            # 64 is the assumed input grid resolution — TODO confirm.
            self.base_size_ss = 64 // upsample_factor
            pos_embedder = AbsolutePositionEmbedder(channels[-1], 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size_ss] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            grid_pos_emb = pos_embedder(coords)
            self.register_buffer("grid_pos_emb_ss", grid_pos_emb)
            # Token projection layer: fuses token + positional embedding (2C -> C).
            self.token_proj_ss = nn.Linear(channels[-1]*2, channels[-1])

            # Out layers: one cross block (tokens query the grid), four
            # self-attention blocks, and a feed-forward head to the latent.
            self.out_layer = nn.ModuleList(
                [TransformerCrossBlock(
                    channels=channels[-1],
                    ctx_channels=channels[-1]*2,
                    out_channels=channels[-1],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                    x_is_query=False)] +
                [TransformerBlock(
                    channels=channels[-1],
                    out_channels=channels[-1],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                ) for _ in range(4)] +
                [FeedForwardNet(
                    channels=channels[-1],
                    # Denoising latents are deterministic: no logvar half.
                    out_channels=latent_channels*2 if not self.latent_denoising else latent_channels)]
            )
        else:
            # Dense head: keep a low-res grid latent.
            self.out_layer = nn.Sequential(
                norm_layer(norm_type, channels[-1]),
                nn.SiLU(),
                nn.Conv3d(channels[-1], latent_channels*2 if not self.latent_denoising else latent_channels, 3, padding=1)
            )

        if self.encode_global_skl:
            # Initial Tokens and PE (skeleton branch, mirroring the shape branch).
            self.init_tokens = nn.Parameter(torch.zeros(1, global_token_num_skl, channels[-1]))
            pos_embedder = AbsolutePositionEmbedder(channels[-1], 1)
            coords = torch.arange(global_token_num_skl, device=self.device).reshape(-1, 1)
            tokens_pos_emb = pos_embedder(coords)
            self.register_buffer('tokens_pos_emb', tokens_pos_emb)
            # Grids PE
            upsample_factor = 2 ** (len(channels) - 1)
            self.base_size = 64 // upsample_factor
            pos_embedder = AbsolutePositionEmbedder(channels[-1], 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            grid_pos_emb = pos_embedder(coords)
            self.register_buffer("grid_pos_emb", grid_pos_emb)
            # Token projection layer
            self.token_proj = nn.Linear(channels[-1]*2, channels[-1])

            # Out layers
            self.out_layer_skl = nn.ModuleList(
                [TransformerCrossBlock(
                    channels=channels[-1],
                    ctx_channels=channels[-1]*2,
                    out_channels=channels[-1],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                    x_is_query=False)] +
                [TransformerBlock(
                    channels=channels[-1],
                    out_channels=channels[-1],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                ) for _ in range(4)] +
                [FeedForwardNet(
                    channels=channels[-1],
                    out_channels=latent_channels_skl*2 if not self.latent_denoising_skl else latent_channels_skl)]
            )
        else:
            self.out_layer_skl = nn.Sequential(
                norm_layer(norm_type, channels[-1]),
                nn.SiLU(),
                nn.Conv3d(channels[-1], latent_channels_skl*2 if not self.latent_denoising_skl else latent_channels_skl, 3, padding=1)
            )

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            # Freeze: self.input_layer, self.blocks, self.middle_block, self.out_layer
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for param in m.parameters():
                                param.requires_grad = False
                    else:
                        for param in mod.parameters():
                            param.requires_grad = False

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        # Input/output convs stay fp32; only the torso (and global heads) cast.
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.blocks_skl.apply(convert_module_to_f16)
        self.middle_block_skl.apply(convert_module_to_f16)
        if self.encode_global_skl:
            self.token_proj.apply(convert_module_to_f16)
            self.out_layer_skl.apply(convert_module_to_f16)
        if self.encode_global:
            self.token_proj_ss.apply(convert_module_to_f16)
            self.out_layer.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.blocks_skl.apply(convert_module_to_f32)
        self.middle_block_skl.apply(convert_module_to_f32)
        if self.encode_global_skl:
            self.token_proj.apply(convert_module_to_f32)
            self.out_layer_skl.apply(convert_module_to_f32)
        if self.encode_global:
            self.token_proj_ss.apply(convert_module_to_f32)
            self.out_layer.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight, nonlinearity='linear')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: torch.Tensor, x_skl: Optional[torch.Tensor] = None, sample_posterior: bool = False, return_raw: bool = False) -> Tuple[torch.Tensor, ...]:
        """Encode a shape grid and a skeleton grid into latents.

        Args:
            x: Dense shape voxel grid, (B, in_channels, 64, 64, 64) assumed.
            x_skl: Dense skeleton voxel grid. Despite the None default it is
                used unconditionally below, so it is effectively required.
            sample_posterior: Sample z from N(mean, std) instead of returning
                the mean (ignored for denoising latents, which have no logvar).
            return_raw: Also return (mean, logvar) for both branches.

        Returns:
            (z, z_skl), or (z, mean, logvar, z_skl, mean_skl, logvar_skl)
            when ``return_raw`` is True.
        """
        h = self.input_layer(x)
        h = h.type(self.dtype)
        h_skl = self.input_layer_skl(x_skl)
        h_skl = h_skl.type(self.dtype)

        # The two towers are built in lockstep, so zip pairs them exactly.
        for block, block_skl in zip(self.blocks, self.blocks_skl):
            h_skl = block_skl(h_skl)
            h = block(h)
        h_skl = self.middle_block_skl(h_skl)
        h = self.middle_block(h)

        if self.encode_global:
            # Flatten the grid to a token sequence and cross-attend the
            # learned query tokens against it.
            B, C, D, H, W = h.shape
            h = h.view(B, C, D*H*W).permute(0, 2, 1) # B, N, C
            h = torch.cat([h, self.grid_pos_emb_ss[None].expand(B, -1, -1)], dim=-1).type(h.dtype)
            init_tokens = torch.cat([self.init_tokens_ss, self.tokens_pos_emb_ss[None].expand_as(self.init_tokens_ss)], dim=-1).type(h.dtype)
            tokens = self.token_proj_ss(init_tokens.expand(B, -1, -1))
            h = self.out_layer[0](tokens, h) # B, global_token_num, C
            for layer in self.out_layer[1:]:
                h = layer(h)
            h = h.type(x.dtype)
            if self.latent_denoising:
                if self.normalize_latent:
                    # Token latents normalize over the channel dim (dim=-1).
                    h = nn.functional.normalize(h, dim=-1) * self.normalize_scale
                mean = h
                logvar = torch.zeros_like(h)
            else:
                mean, logvar = h.chunk(2, dim=2) # B, global_token_num, C
            if sample_posterior and not self.latent_denoising:
                std = torch.exp(0.5 * logvar)
                z = mean + std * torch.randn_like(std)
            else:
                z = mean
        else:
            h = h.type(x.dtype)
            h = self.out_layer(h)
            if self.latent_denoising:
                if self.normalize_latent:
                    # Grid latents normalize over the channel dim (dim=1).
                    h = nn.functional.normalize(h, dim=1) * self.normalize_scale
                mean = h
                logvar = torch.zeros_like(h)
            else:
                mean, logvar = h.chunk(2, dim=1)
            if sample_posterior and not self.latent_denoising:
                std = torch.exp(0.5 * logvar)
                z = mean + std * torch.randn_like(std)
            else:
                z = mean

        if self.encode_global_skl:
            # Skeleton branch, same token-compression scheme as above.
            B, C, D, H, W = h_skl.shape
            h_skl = h_skl.view(B, C, D*H*W).permute(0, 2, 1) # B, N, C
            h_skl = torch.cat([h_skl, self.grid_pos_emb[None].expand(B, -1, -1)], dim=-1).type(h_skl.dtype)
            init_tokens = torch.cat([self.init_tokens, self.tokens_pos_emb[None].expand_as(self.init_tokens)], dim=-1).type(h_skl.dtype)
            tokens = self.token_proj(init_tokens.expand(B, -1, -1))
            h_skl = self.out_layer_skl[0](tokens, h_skl) # B, global_token_num_skl, C
            for layer in self.out_layer_skl[1:]:
                h_skl = layer(h_skl)
            h_skl = h_skl.type(x_skl.dtype)
            if self.latent_denoising_skl:
                if self.normalize_latent_skl:
                    h_skl = nn.functional.normalize(h_skl, dim=-1) * self.normalize_scale
                mean_skl = h_skl
                logvar_skl = torch.zeros_like(h_skl)
            else:
                mean_skl, logvar_skl = h_skl.chunk(2, dim=2) # B, global_token_num_skl, C
            if sample_posterior and not self.latent_denoising_skl:
                std_skl = torch.exp(0.5 * logvar_skl)
                z_skl = mean_skl + std_skl * torch.randn_like(std_skl)
            else:
                z_skl = mean_skl
        else:
            h_skl = h_skl.type(x_skl.dtype)
            h_skl = self.out_layer_skl(h_skl)
            if self.latent_denoising_skl:
                if self.normalize_latent_skl:
                    h_skl = nn.functional.normalize(h_skl, dim=1) * self.normalize_scale
                mean_skl = h_skl
                logvar_skl = torch.zeros_like(h_skl)
            else:
                mean_skl, logvar_skl = h_skl.chunk(2, dim=1)
            if sample_posterior and not self.latent_denoising_skl:
                std_skl = torch.exp(0.5 * logvar_skl)
                z_skl = mean_skl + std_skl * torch.randn_like(std_skl)
            else:
                z_skl = mean_skl

        # Detach denoising targets so gradients do not flow back into the
        # encoder through the returned mean (z itself stays attached).
        if self.latent_denoising:
            mean = mean.detach()
        if self.latent_denoising_skl:
            mean_skl = mean_skl.detach()

        if return_raw:
            return z, mean, logvar, z_skl, mean_skl, logvar_skl
        return z, z_skl
448
+
449
+
450
class AniGenSparseStructureDecoder(nn.Module):
    """
    Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).

    Mirror of ``AniGenSparseStructureEncoder``: two parallel towers upsample
    the shape latent (``x``) and skeleton latent (``x_skl``) back to dense
    voxel grids. Token latents (``encode_global*``) are first expanded onto a
    learned low-res grid via cross-attention; dense latents go straight
    through a convolution.

    Args:
        out_channels (int): Channels of the shape output.
        out_channels_skl (int): Channels of the skeleton output.
        latent_channels (int): Channels of the shape latent representation.
        latent_channels_skl (int): Channels of the skeleton latent.
        num_res_blocks (int): Number of residual blocks at each resolution.
        channels (List[int]): Channels of the decoder blocks.
        num_res_blocks_middle (int): Number of residual blocks in the middle.
        norm_type (Literal["group", "layer"]): Type of normalization layer.
        use_fp16 (bool): Whether to use FP16.
        encode_global / encode_global_skl (bool): Whether the corresponding
            latent is a token sequence (B, N, C) instead of a dense grid.
        normalize_z / normalize_z_skl (bool): L2-normalize the incoming
            latent before decoding (must match the encoder's setting).
        normalize_scale (float): Scale applied after L2 normalization.
    """
    def __init__(
        self,
        out_channels: int,
        out_channels_skl: int,
        latent_channels: int,
        latent_channels_skl: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
        use_fp16: bool = False,
        encode_global: bool = False,
        global_token_num: int = 1024,
        encode_global_skl: bool = True,
        global_token_num_skl: int = 1024,
        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        # NOTE(review): mutable default list shared across instances; read-only
        # here, but a None default with a fallback would be safer.
        modules_to_freeze: Optional[List[str]] = ["input_layer", "blocks", "middle_block", "out_layer"],
        normalize_z: bool = False,
        normalize_z_skl: bool = True,
        normalize_scale: float = 1.0,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.out_channels_skl = out_channels_skl
        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.use_fp16 = use_fp16
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.encode_global = encode_global
        self.global_token_num = global_token_num
        self.encode_global_skl = encode_global_skl
        self.global_token_num_skl = global_token_num_skl
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.normalize_z = normalize_z
        self.normalize_z_skl = normalize_z_skl
        self.normalize_scale = normalize_scale

        if self.encode_global:
            # Initial Grids and PE: a learned low-res grid of queries that the
            # latent tokens are cross-attended into.
            upsample_factor = 2 ** (len(channels) - 1)
            # 64 is the assumed output grid resolution — TODO confirm.
            self.base_size_ss = 64 // upsample_factor
            self.init_grids_ss = nn.Parameter(torch.zeros(1, channels[0], self.base_size_ss**3).permute(0, 2, 1).contiguous().clone()) # 1, N, C
            pos_embedder = AbsolutePositionEmbedder(channels[0], 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size_ss] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            grid_pos_emb = pos_embedder(coords)
            self.register_buffer("grid_pos_emb_ss", grid_pos_emb)
            # Tokens PE
            pos_embedder = AbsolutePositionEmbedder(channels[0], 1)
            coords = torch.arange(global_token_num, device=self.device).reshape(-1, 1)
            tokens_pos_emb = pos_embedder(coords)
            self.register_buffer('tokens_pos_emb_ss', tokens_pos_emb)
            # Token projection layer: fuses grid query + positional embedding.
            self.token_proj_ss = nn.Linear(channels[0]*2, channels[0])

            # Input layers: four self-attention blocks over the latent tokens
            # (the first takes latent + token PE concatenated), then one cross
            # block where the grid queries attend to the tokens.
            self.input_layer = nn.ModuleList(
                [TransformerBlock(
                    channels=channels[0] if _ != 0 else latent_channels + channels[0],
                    out_channels=channels[0],
                    num_heads=4 if _ == 0 else 16,
                    attn_mode="full",
                    qkv_bias=False,
                ) for _ in range(4)] +
                [TransformerCrossBlock(
                    channels=channels[0],
                    ctx_channels=channels[0],
                    out_channels=channels[0],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                    x_is_query=False)]
            )
        else:
            self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)

        if self.encode_global_skl:
            # Initial Grids and PE (skeleton branch, mirroring the shape branch).
            upsample_factor = 2 ** (len(channels) - 1)
            self.base_size = 64 // upsample_factor
            self.init_grids = nn.Parameter(torch.zeros(1, channels[0], self.base_size**3).permute(0, 2, 1).contiguous().clone()) # 1, N, C
            pos_embedder = AbsolutePositionEmbedder(channels[0], 3)
            coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size] * 3], indexing='ij')
            coords = torch.stack(coords, dim=-1).reshape(-1, 3)
            grid_pos_emb = pos_embedder(coords)
            self.register_buffer("grid_pos_emb", grid_pos_emb)
            # Tokens PE
            pos_embedder = AbsolutePositionEmbedder(channels[0], 1)
            coords = torch.arange(global_token_num_skl, device=self.device).reshape(-1, 1)
            tokens_pos_emb = pos_embedder(coords)
            self.register_buffer('tokens_pos_emb', tokens_pos_emb)
            # Token projection layer
            self.token_proj = nn.Linear(channels[0]*2, channels[0])

            # Input layers
            self.input_layer_skl = nn.ModuleList(
                [TransformerBlock(
                    channels=channels[0] if _ != 0 else latent_channels_skl + channels[0],
                    out_channels=channels[0],
                    num_heads=4 if _ == 0 else 16,
                    attn_mode="full",
                    qkv_bias=False,
                ) for _ in range(4)] +
                [TransformerCrossBlock(
                    channels=channels[0],
                    ctx_channels=channels[0],
                    out_channels=channels[0],
                    num_heads=16,
                    attn_mode="full",
                    qkv_bias=False,
                    x_is_query=False)]
            )
        else:
            self.input_layer_skl = nn.Conv3d(latent_channels_skl, channels[0], 3, padding=1)

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[0], channels[0])
            for _ in range(num_res_blocks_middle)
        ])
        self.middle_block_skl = nn.Sequential(*[
            # NOTE(review): both branches of this conditional are identical
            # (channels[0]); likely a leftover from an earlier variant.
            ResBlock3d(channels[0] if _ == 0 else channels[0], channels[0])
            for _ in range(num_res_blocks_middle)
        ])

        # Mirrored residual/upsample towers for the shape and skeleton branches.
        self.blocks = nn.ModuleList([])
        self.blocks_skl = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    UpsampleBlock3d(ch, channels[i+1])
                )
            self.blocks_skl.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks_skl.append(
                    UpsampleBlock3d(ch, channels[i+1])
                )

        self.out_layer = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], self.out_channels, 3, padding=1)
        )
        self.out_layer_skl = nn.Sequential(
            norm_layer(norm_type, channels[-1]),
            nn.SiLU(),
            nn.Conv3d(channels[-1], self.out_channels_skl, 3, padding=1)
        )

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            # Freeze: self.input_layer, self.blocks, self.middle_block, self.out_layer
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for param in m.parameters():
                                param.requires_grad = False
                    else:
                        for param in mod.parameters():
                            param.requires_grad = False

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        # Output convs stay fp32; the torso and token input heads cast.
        self.use_fp16 = True
        self.dtype = torch.float16
        self.blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.blocks_skl.apply(convert_module_to_f16)
        self.middle_block_skl.apply(convert_module_to_f16)
        if self.encode_global_skl:
            self.token_proj.apply(convert_module_to_f16)
            self.input_layer_skl.apply(convert_module_to_f16)
        if self.encode_global:
            self.token_proj_ss.apply(convert_module_to_f16)
            self.input_layer.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.use_fp16 = False
        self.dtype = torch.float32
        self.blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.blocks_skl.apply(convert_module_to_f32)
        self.middle_block_skl.apply(convert_module_to_f32)
        if self.encode_global_skl:
            self.token_proj.apply(convert_module_to_f32)
            self.input_layer_skl.apply(convert_module_to_f32)
        if self.encode_global:
            self.token_proj_ss.apply(convert_module_to_f32)
            self.input_layer.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_uniform_(module.weight, nonlinearity='linear')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: torch.Tensor, x_skl: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode shape and skeleton latents back to dense voxel grids.

        Args:
            x: Shape latent — (B, N, C) tokens when ``encode_global`` else a
                dense (B, C, d, d, d) grid.
            x_skl: Skeleton latent, analogous shape per ``encode_global_skl``.

        Returns:
            Tuple ``(h, h_skl)`` of decoded dense grids.
        """
        # NOTE(review): normalization here uses dim=1, which for (B, N, C)
        # token latents is the token dim, while the encoder normalized token
        # latents over dim=-1 (channels). Verify the intended axis.
        h = F.normalize(x, dim=1) * self.normalize_scale if self.normalize_z else x
        h_skl = F.normalize(x_skl, dim=1) * self.normalize_scale if self.normalize_z_skl else x_skl
        if self.encode_global:
            B, _, _ = h.shape
            h = torch.cat([h, self.tokens_pos_emb_ss[None].expand(B, -1, -1)], dim=-1).type(self.dtype)
            h = h.type(self.dtype)
            for layer in self.input_layer[:-1]:
                h = layer(h)
            init_grids = torch.cat([self.init_grids_ss, self.grid_pos_emb_ss[None].expand_as(self.init_grids_ss)], dim=-1).type(self.dtype)
            grids = self.token_proj_ss(init_grids.expand(B, -1, -1))
            h = self.input_layer[-1](grids, h) # B, N, C
            # NOTE(review): this uses self.base_size, which only exists when
            # encode_global_skl is True (base_size_ss looks intended; the two
            # are equal when both flags are set, but this raises
            # AttributeError if encode_global is used without the skl branch).
            h = h.permute(0, 2, 1).view(B, -1, self.base_size, self.base_size, self.base_size)
        else:
            h = self.input_layer(h)
            h = h.type(self.dtype)
        if self.encode_global_skl:
            B, _, _ = h_skl.shape
            h_skl = torch.cat([h_skl, self.tokens_pos_emb[None].expand(B, -1, -1)], dim=-1).type(self.dtype)
            h_skl = h_skl.type(self.dtype)
            for layer in self.input_layer_skl[:-1]:
                h_skl = layer(h_skl)
            init_grids = torch.cat([self.init_grids, self.grid_pos_emb[None].expand_as(self.init_grids)], dim=-1).type(self.dtype)
            grids = self.token_proj(init_grids.expand(B, -1, -1))
            h_skl = self.input_layer_skl[-1](grids, h_skl) # B, N, C
            h_skl = h_skl.permute(0, 2, 1).view(B, -1, self.base_size, self.base_size, self.base_size)
        else:
            h_skl = self.input_layer_skl(h_skl)
            h_skl = h_skl.type(self.dtype)
        h_skl = self.middle_block_skl(h_skl)
        h = self.middle_block(h)
        # The two towers are built in lockstep, so zip pairs them exactly.
        for block, block_skl in zip(self.blocks, self.blocks_skl):
            h_skl = block_skl(h_skl)
            h = block(h)
        h = h.type(x.dtype)
        h = self.out_layer(h)
        h_skl = h_skl.type(x.dtype)
        h_skl = self.out_layer_skl(h_skl)
        return h, h_skl
anigen/models/anigen_structured_latent_flow.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+
8
+ from anigen.modules.transformer import blocks
9
+ from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
10
+ from ..modules.transformer import AbsolutePositionEmbedder
11
+ from ..modules.norm import LayerNorm32
12
+ from ..modules import sparse as sp
13
+ from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
14
+ from .sparse_elastic_mixin import SparseTransformerElasticMixin
15
+
16
+
17
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal frequency embedding of each timestep is projected through a
    two-layer SiLU MLP to the model's hidden size.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        # Layer layout is fixed so existing checkpoints keep loading.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t: a 1-D Tensor of N indices, one per batch element.
               These may be fractional.
            dim: the dimension of the output.
            max_period: controls the minimum frequency of the embeddings.

        Returns:
            an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-np.log(max_period) * exponents).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        parts = [torch.cos(phases), torch.sin(phases)]
        if dim % 2:
            # Odd dims get one zero column so the output is exactly (N, dim).
            parts.append(torch.zeros_like(phases[:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t):
        """Map timesteps (N,) to embeddings (N, hidden_size)."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
59
+
60
+
61
class SparseResBlock3d(nn.Module):
    """Sparse 3D residual block modulated by a timestep/condition embedding.

    The embedding is projected to per-channel (scale, shift) terms that
    modulate the second (affine-free) normalization, AdaLN-style. At most one
    of ``downsample``/``upsample`` may be set; the resolution change is
    applied before the residual computation so the skip path matches.
    """

    def __init__(
        self,
        channels: int,
        emb_channels: int,
        out_channels: Optional[int] = None,
        downsample: bool = False,
        upsample: bool = False,
    ):
        super().__init__()
        assert not (downsample and upsample), "Cannot downsample and upsample at the same time"

        self.channels = channels
        self.emb_channels = emb_channels
        self.out_channels = out_channels or channels
        self.downsample = downsample
        self.upsample = upsample

        self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        # Affine-free: the affine transform comes from the embedding instead.
        self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
        self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
        # Zero-init so the residual branch starts as an identity mapping.
        self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
        )
        if channels != self.out_channels:
            self.skip_connection = sp.SparseLinear(channels, self.out_channels)
        else:
            self.skip_connection = nn.Identity()
        if self.downsample:
            self.updown = sp.SparseDownsample(2)
        elif self.upsample:
            self.updown = sp.SparseUpsample(2)
        else:
            self.updown = None

    def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
        # Apply the optional resolution change; identity when unconfigured.
        return self.updown(x) if self.updown is not None else x

    def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
        """Run the block on sparse features `x`, modulated by `emb`."""
        # Per-channel modulation terms from the embedding, cast to x's dtype.
        scale, shift = torch.chunk(self.emb_layers(emb).type(x.dtype), 2, dim=1)

        x = self._updown(x)
        out = self.conv1(x.replace(F.silu(self.norm1(x.feats))))
        # AdaLN-style modulation on the sparse features.
        out = out.replace(self.norm2(out.feats)) * (1 + scale) + shift
        out = self.conv2(out.replace(F.silu(out.feats)))
        return out + self.skip_connection(x)
113
+
114
+
115
class AniGenSLatFlowModel(nn.Module):
    """Structured-latent flow model with three coupled branches.

    Branches are addressed by attribute postfix throughout:
      * ``''``           — geometry latents, cross-attending to the image condition;
      * ``'_vert_skin'`` — per-vertex skinning latents;
      * ``'_skl'``       — skeleton latents.

    Conditioning causality (see ``causial_cond_channels_dict``):
      * Geometry   <- conditioned image (cross attention);
      * Skinning   <- geometry (zero-init adapter) + skeleton (cross attention);
      * Skeleton   <- skinning (cross attention).
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        in_channels_vert_skin: int,
        in_channels_skl: int,
        model_channels: int,
        model_channels_vert_skin: int,
        model_channels_skl: int,
        cond_channels: int,
        out_channels: int,
        out_channels_vert_skin: int,
        out_channels_skl: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        num_heads_vert_skin: Optional[int] = None,
        num_head_channels_vert_skin: Optional[int] = 64,
        num_heads_skl: Optional[int] = None,
        num_head_channels_skl: Optional[int] = 64,
        mlp_ratio: float = 4,
        patch_size: int = 2,
        num_io_res_blocks: int = 2,
        num_io_res_blocks_vert_skin: int = 2,
        num_io_res_blocks_skl: int = 2,
        io_block_channels: List[int] = None,
        io_block_channels_vert_skin: List[int] = None,
        io_block_channels_skl: List[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        use_skip_connection: bool = True,
        share_mod: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        # NOTE(review): mutable default list — it is only iterated below, never
        # mutated, so this is safe in practice; a None sentinel would be cleaner.
        modules_to_freeze: Optional[List[str]] = ['blocks', 'input_blocks','input_layer', 'out_blocks', 'out_layer', 't_embedder'],
        predict_x0: bool = False,
        t_eps: float = 5e-2,
        t_scale: float = 1e3,
        use_joint_num_cond: bool = False,
        joint_num_max: int = 60,
        joint_num_fourier_bands: int = 6,
    ):
        super().__init__()

        # Class name(s) of the pretrained geometry-branch checkpoints this model
        # can initialize from.
        self.pretrain_class_name = ["AniGenSlatFlowImage"]

        self.resolution = resolution
        self.in_channels = in_channels
        self.in_channels_vert_skin = in_channels_vert_skin
        self.in_channels_skl = in_channels_skl
        self.model_channels = model_channels
        self.model_channels_vert_skin = model_channels_vert_skin
        self.model_channels_skl = model_channels_skl
        self.cond_channels = cond_channels
        self.out_channels = out_channels
        self.out_channels_vert_skin = out_channels_vert_skin
        self.out_channels_skl = out_channels_skl
        self.num_blocks = num_blocks
        # Head counts default to channels / head_channels when not given explicitly.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.num_heads_vert_skin = num_heads_vert_skin or model_channels_vert_skin // num_head_channels_vert_skin
        self.num_heads_skl = num_heads_skl or model_channels_skl // num_head_channels_skl
        self.mlp_ratio = mlp_ratio
        self.patch_size = patch_size
        self.num_io_res_blocks = num_io_res_blocks
        self.num_io_res_blocks_vert_skin = num_io_res_blocks_vert_skin
        self.num_io_res_blocks_skl = num_io_res_blocks_skl
        self.io_block_channels = io_block_channels
        self.io_block_channels_vert_skin = io_block_channels_vert_skin
        self.io_block_channels_skl = io_block_channels_skl
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.use_skip_connection = use_skip_connection
        self.share_mod = share_mod
        self.qk_rms_norm = qk_rms_norm
        self.qk_rms_norm_cross = qk_rms_norm_cross
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.predict_x0 = predict_x0
        self.t_eps = t_eps
        self.t_scale = t_scale
        self.use_joint_num_cond = use_joint_num_cond
        self.joint_num_max = joint_num_max
        self.joint_num_fourier_bands = joint_num_fourier_bands

        if self.io_block_channels is not None:
            # Each power-of-two downsampling stage needs one IO channel spec.
            assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
            assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"

        # Per-branch timestep embedders.
        self.t_embedder = TimestepEmbedder(model_channels)
        self.t_embedder_vert_skin = TimestepEmbedder(model_channels_vert_skin)
        self.t_embedder_skl = TimestepEmbedder(model_channels_skl)

        if self.use_joint_num_cond:
            # Joint-number conditioning (applied to skin + skeleton branches).
            # If joints_num is missing/<=0, use learnable unconditional embeddings.
            self.joint_num_embedder_vert_skin = nn.Sequential(
                nn.Linear(2 * joint_num_fourier_bands, model_channels_vert_skin, bias=True),
                nn.SiLU(),
                nn.Linear(model_channels_vert_skin, model_channels_vert_skin, bias=True),
            )
            self.joint_num_embedder_skl = nn.Sequential(
                nn.Linear(2 * joint_num_fourier_bands, model_channels_skl, bias=True),
                nn.SiLU(),
                nn.Linear(model_channels_skl, model_channels_skl, bias=True),
            )
            self.joint_num_uncond_vert_skin = nn.Parameter(torch.zeros(model_channels_vert_skin))
            self.joint_num_uncond_skl = nn.Parameter(torch.zeros(model_channels_skl))
        if share_mod:
            # A single adaLN modulation shared by all transformer blocks of a branch.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels, 6 * model_channels, bias=True)
            )
            self.adaLN_modulation_vert_skin = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels_vert_skin, 6 * model_channels_vert_skin, bias=True)
            )
            self.adaLN_modulation_skl = nn.Sequential(
                nn.SiLU(),
                nn.Linear(model_channels_skl, 6 * model_channels_skl, bias=True)
            )

        if pe_mode == "ape":
            # Absolute positional embeddings on voxel coordinates (one per branch).
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)
            self.pos_embedder_vert_skin = AbsolutePositionEmbedder(model_channels_vert_skin)
            self.pos_embedder_skl = AbsolutePositionEmbedder(model_channels_skl)

        # Causality in conditioning:
        #   Geometry <- Conditioned Image (Cross Attention)
        #   Skinning <- Geometry (Adapter Layer) + Skeleton (Cross Attention)
        #   Skeleton <- Skinning (Cross Attention)
        # Maps branch postfix -> channel width of that branch's cross-attention context.
        causial_cond_channels_dict = {'': cond_channels, '_vert_skin': self.model_channels_skl, '_skl': self.model_channels_vert_skin}

        for postfix in ['', '_vert_skin', '_skl']:
            # Input blocks: linear lift, then (optional) residual blocks that
            # downsample by 2 per IO stage.
            setattr(self, f'input_layer{postfix}', sp.SparseLinear(
                getattr(self, f'in_channels{postfix}'),
                getattr(self, f'model_channels{postfix}') if getattr(self, f'io_block_channels{postfix}') is None else getattr(self, f'io_block_channels{postfix}')[0]
            ))

            setattr(self, f'input_blocks{postfix}', nn.ModuleList([]))
            io_block_channels = getattr(self, f'io_block_channels{postfix}')
            model_channels = getattr(self, f'model_channels{postfix}')
            num_io_res_blocks = getattr(self, f'num_io_res_blocks{postfix}')
            if io_block_channels is not None:
                for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
                    getattr(self, f'input_blocks{postfix}').extend([
                        SparseResBlock3d(
                            chs,
                            model_channels,
                            out_channels=chs,
                        )
                        for _ in range(num_io_res_blocks-1)
                    ])
                    # The last block of the stage changes channels and downsamples.
                    getattr(self, f'input_blocks{postfix}').append(
                        SparseResBlock3d(
                            chs,
                            model_channels,
                            out_channels=next_chs,
                            downsample=True,
                        )
                    )

            # Transformer blocks: full self-attention + cross-attention to the
            # branch-specific context (image / other branch, see dict above).
            cond_channels_block = causial_cond_channels_dict[postfix]
            setattr(self, f'blocks{postfix}', nn.ModuleList([
                ModulatedSparseTransformerCrossBlock(
                    getattr(self, f'model_channels{postfix}'),
                    cond_channels_block,
                    num_heads=getattr(self, f'num_heads{postfix}'),
                    mlp_ratio=self.mlp_ratio,
                    attn_mode='full',
                    use_checkpoint=self.use_checkpoint,
                    use_rope=(pe_mode == "rope"),
                    share_mod=self.share_mod,
                    qk_rms_norm=self.qk_rms_norm,
                    qk_rms_norm_cross=self.qk_rms_norm_cross,
                    norm_for_context=True,
                )
                for _ in range(num_blocks)
            ]))

            # Output blocks: mirror of the input blocks (upsample per stage),
            # optionally concatenating U-Net-style skip features (hence `* 2`).
            setattr(self, f'out_blocks{postfix}', nn.ModuleList([]))
            if io_block_channels is not None:
                for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
                    getattr(self, f'out_blocks{postfix}').append(
                        SparseResBlock3d(
                            prev_chs * 2 if self.use_skip_connection else prev_chs,
                            model_channels,
                            out_channels=chs,
                            upsample=True,
                        )
                    )
                    getattr(self, f'out_blocks{postfix}').extend([
                        SparseResBlock3d(
                            chs * 2 if self.use_skip_connection else chs,
                            model_channels,
                            out_channels=chs,
                        )
                        for _ in range(num_io_res_blocks-1)
                    ])
            setattr(self, f'out_layer{postfix}', sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], getattr(self, f'out_channels{postfix}')))

        # Zero-initialized per-block adapters injecting geometry features into
        # the skinning branch (see initialize_weights).
        self.adapter_geo_to_skin = nn.ModuleList([
            sp.SparseLinear(self.model_channels, self.model_channels_vert_skin) for _ in range(num_blocks)
        ])

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        # self.is_geometry_branch_frozen = self.use_pretrain_branch and self.freeze_pretrain_branch and all([module in modules_to_freeze for module in ['blocks', 'input_blocks','input_layer', 'out_blocks', 'out_layer', 't_embedder']])

        # Freeze the listed (pretrained geometry-branch) modules when requested.
        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for param in m.parameters():
                                param.requires_grad = False
                    else:
                        for param in mod.parameters():
                            param.requires_grad = False

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.

        Input/output layers and timestep embedders stay in float32; only the
        residual/transformer torso and the conditioning modules are converted.
        """
        for postfix in ['', '_vert_skin', '_skl']:
            getattr(self, f'input_blocks{postfix}').apply(convert_module_to_f16)
            getattr(self, f'blocks{postfix}').apply(convert_module_to_f16)
            getattr(self, f'out_blocks{postfix}').apply(convert_module_to_f16)
        self.adapter_geo_to_skin.apply(convert_module_to_f16)
        if self.use_joint_num_cond:
            self.joint_num_embedder_vert_skin.apply(convert_module_to_f16)
            self.joint_num_embedder_skl.apply(convert_module_to_f16)
            self.joint_num_uncond_vert_skin.data = self.joint_num_uncond_vert_skin.data.half()
            self.joint_num_uncond_skl.data = self.joint_num_uncond_skl.data.half()

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32 (inverse of convert_to_fp16).
        """
        for postfix in ['', '_vert_skin', '_skl']:
            getattr(self, f'input_blocks{postfix}').apply(convert_module_to_f32)
            getattr(self, f'blocks{postfix}').apply(convert_module_to_f32)
            getattr(self, f'out_blocks{postfix}').apply(convert_module_to_f32)
        self.adapter_geo_to_skin.apply(convert_module_to_f32)
        if self.use_joint_num_cond:
            self.joint_num_embedder_vert_skin.apply(convert_module_to_f32)
            self.joint_num_embedder_skl.apply(convert_module_to_f32)
            self.joint_num_uncond_vert_skin.data = self.joint_num_uncond_vert_skin.data.float()
            self.joint_num_uncond_skl.data = self.joint_num_uncond_skl.data.float()

    def initialize_weights(self) -> None:
        """Initialize weights: Xavier for linears, DiT-style zeros for
        modulation/out layers, zeros for the geometry->skin adapters."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        for postfix in ['', '_vert_skin', '_skl']:
            nn.init.normal_(getattr(self, f't_embedder{postfix}').mlp[0].weight, std=0.02)
            nn.init.normal_(getattr(self, f't_embedder{postfix}').mlp[2].weight, std=0.02)
            # Zero-init the adaLN modulation so each block starts as identity.
            if self.share_mod:
                nn.init.constant_(getattr(self, f'adaLN_modulation{postfix}')[-1].weight, 0)
                nn.init.constant_(getattr(self, f'adaLN_modulation{postfix}')[-1].bias, 0)
            else:
                for block in getattr(self, f'blocks{postfix}'):
                    nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                    nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
            # Zero-init the final projection so the model starts predicting zeros.
            nn.init.constant_(getattr(self, f'out_layer{postfix}').weight, 0)
            nn.init.constant_(getattr(self, f'out_layer{postfix}').bias, 0)

        # Zero-init the adapters so the skin branch initially ignores geometry.
        for layer in self.adapter_geo_to_skin:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

        if self.use_joint_num_cond:
            # Joint-number conditioning layers
            for emb in [self.joint_num_embedder_vert_skin, self.joint_num_embedder_skl]:
                for m in emb.modules():
                    if isinstance(m, nn.Linear):
                        torch.nn.init.xavier_uniform_(m.weight)
                        if m.bias is not None:
                            nn.init.constant_(m.bias, 0)

    def _fourier_encode_joint_num(self, joints_num: torch.Tensor) -> torch.Tensor:
        """Fourier features for joints_num in [0, joint_num_max].

        Returns a [B, 2 * joint_num_fourier_bands] tensor of sin/cos features.
        """
        # Keep dtype consistent with model (e.g., fp16) to avoid Linear dtype mismatch.
        dtype = getattr(self, 'dtype', torch.float32)
        x = (joints_num.to(dtype=dtype) / float(self.joint_num_max)).clamp(0.0, 1.0)
        x = x[:, None]
        freqs = (2.0 ** torch.arange(self.joint_num_fourier_bands, device=x.device, dtype=x.dtype)) * math.pi
        angles = x * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def _get_joint_num_emb(self, joints_num: Optional[torch.Tensor], batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (emb_vert_skin, emb_skl), shape [B, C_*].

        Missing or non-positive joints_num entries fall back to the learned
        unconditional embeddings.
        """
        # Normalize joints_num to a [B] tensor on the right device.
        if joints_num is None:
            joints_num = torch.zeros(batch_size, device=device)
        elif not torch.is_tensor(joints_num):
            joints_num = torch.tensor(joints_num, device=device)
        joints_num = joints_num.to(device=device)
        if joints_num.dim() == 0:
            joints_num = joints_num[None].expand(batch_size)
        joints_num = joints_num.reshape(batch_size)

        mask_dtype = getattr(self, 'dtype', torch.float32)
        uncond_mask = (joints_num <= 0).to(dtype=mask_dtype, device=device)[:, None]
        joints_num = joints_num.clamp(min=0, max=self.joint_num_max)

        fourier = self._fourier_encode_joint_num(joints_num)
        emb_vs_cond = self.joint_num_embedder_vert_skin(fourier)
        emb_skl_cond = self.joint_num_embedder_skl(fourier)

        emb_vs_uncond = self.joint_num_uncond_vert_skin[None].expand(batch_size, -1)
        emb_skl_uncond = self.joint_num_uncond_skl[None].expand(batch_size, -1)

        # Blend: uncond_mask==1 -> unconditional, uncond_mask==0 -> conditional.
        emb_vs = emb_vs_cond * (1.0 - uncond_mask) + emb_vs_uncond * uncond_mask
        emb_skl = emb_skl_cond * (1.0 - uncond_mask) + emb_skl_uncond * uncond_mask
        return emb_vs, emb_skl

    def forward_stage(
        self,
        x: sp.SparseTensor,
        t: torch.Tensor,
        postfix,
        stage,
        cond_emb: Optional[torch.Tensor] = None,
        t_emb=None,
        skips=None,
        original_dtype=None,
    ) -> sp.SparseTensor:
        """Run the input ('in') or output ('out') stage of one branch.

        Args:
            x: branch sparse tensor (noisy latent for 'in', torso output for 'out').
            t: flow timestep.
            postfix: branch selector ('', '_vert_skin' or '_skl').
            stage: 'in' or 'out'.
            cond_emb: optional extra embedding added to the timestep embedding ('in').
            t_emb, skips, original_dtype: values returned by the 'in' stage,
                required for the 'out' stage.

        Returns:
            'in'  -> (h, t_emb, t_mod, skips);
            'out' -> the branch output sparse tensor in ``original_dtype``.

        Raises:
            ValueError: if ``stage`` is neither 'in' nor 'out'.
        """
        input_layer = getattr(self, f'input_layer{postfix}')
        t_embedder = getattr(self, f't_embedder{postfix}')
        input_blocks = getattr(self, f'input_blocks{postfix}')
        pos_embedder = getattr(self, f'pos_embedder{postfix}')
        out_blocks = getattr(self, f'out_blocks{postfix}')
        out_layer = getattr(self, f'out_layer{postfix}')
        adaLN_modulation = getattr(self, f'adaLN_modulation{postfix}') if self.share_mod else None

        if stage == 'in':
            h = input_layer(x).type(self.dtype)
            t_emb = t_embedder(t)
            if cond_emb is not None:
                t_emb = t_emb + cond_emb
            t_emb = t_emb.type(self.dtype)
            # With share_mod, modulation params are precomputed here once.
            t_mod = adaLN_modulation(t_emb).type(self.dtype) if self.share_mod else t_emb
            skips = []
            # pack with input blocks
            for block in input_blocks:
                h = block(h, t_emb)
                skips.append(h.feats)
            if self.pe_mode == "ape":
                # coords[:, 0] is the batch index; positions are coords[:, 1:].
                h = h + pos_embedder(h.coords[:, 1:]).type(self.dtype)
            return h, t_emb, t_mod, skips
        elif stage == 'out':
            h = x
            # unpack with output blocks (skips consumed in reverse order)
            for block, skip in zip(out_blocks, reversed(skips)):
                if self.use_skip_connection:
                    h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
                else:
                    h = block(h, t_emb)
            h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
            h = out_layer(h.type(original_dtype))
            return h
        else:
            raise ValueError(f"Unknown stage: {stage}")

    def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor, joints_num: Optional[torch.Tensor] = None, **kwargs) -> sp.SparseTensor:
        """Joint denoising step for geometry + skinning + skeleton.

        Args:
            x: sparse tensor whose feature channels are the concatenation of
                geometry latents (first ``in_channels``) and per-vertex skinning
                latents (the remainder).
            x_skl: skeleton-branch sparse tensor.
            t: flow timestep, scaled by ``t_scale``.
            cond: image condition tokens for the geometry branch.
            joints_num: optional per-sample joint counts (used only when
                ``use_joint_num_cond`` is set; <=0 means unconditional).

        Returns:
            (x_out, x_skl_out): x_out re-concatenates geometry and skinning
            predictions along the channel dim; x_skl_out is the skeleton
            prediction. With ``predict_x0``, skin/skeleton x0 predictions are
            converted to velocities via (x0_pred - x_t) / t.
        """
        cond = cond.type(self.dtype)
        # Split the stacked input into geometry and skinning channel groups.
        feats, feats_vert_skin = x.feats[:, :self.in_channels], x.feats[:, self.in_channels:]
        x, x_vert_skin = x.replace(feats), x.replace(feats_vert_skin)
        if self.predict_x0:
            # Keep the noisy inputs so x0 predictions can be turned into velocities.
            xt_feats_skin, xt_feats_skl = feats_vert_skin.clone(), x_skl.feats.clone()

        joint_emb_vs, joint_emb_skl = None, None
        if self.use_joint_num_cond:
            # joint-number conditioning for skin + skeleton
            joint_emb_vs, joint_emb_skl = self._get_joint_num_emb(joints_num, x.shape[0], x.device)
            joint_emb_vs = joint_emb_vs.type(self.dtype)
            joint_emb_skl = joint_emb_skl.type(self.dtype)

        in_dicts = {'': x, '_vert_skin': x_vert_skin, '_skl': x_skl}
        cond_emb_dicts = {'': None, '_vert_skin': joint_emb_vs, '_skl': joint_emb_skl}
        # Snapshot the branch keys: the dict grows t_emb/t_mod/skips entries below.
        postfix_keys = list(in_dicts.keys())
        for postfix in postfix_keys:
            cond_emb = cond_emb_dicts[postfix]
            in_dicts[postfix], in_dicts[f't_emb{postfix}'], in_dicts[f't_mod{postfix}'], in_dicts[f'skips{postfix}'] = self.forward_stage(in_dicts[postfix], t, postfix, stage='in', cond_emb=cond_emb)
        # Torso: all three branches advance in lock-step, each cross-attending
        # to a snapshot (h, h_skin, h_skl) of the previous layer's outputs.
        for block, block_skin, block_skl, adapter in zip(self.blocks, self.blocks_vert_skin, self.blocks_skl, self.adapter_geo_to_skin):
            h, h_skin, h_skl = in_dicts[''], in_dicts['_vert_skin'], in_dicts['_skl']
            f = block(h, in_dicts['t_mod'], cond)
            f_skin = block_skin(h_skin, in_dicts['t_mod_vert_skin'], h_skl) + adapter(h)
            f_skl = block_skl(h_skl, in_dicts['t_mod_skl'], h_skin)
            in_dicts[''], in_dicts['_vert_skin'], in_dicts['_skl'] = f, f_skin, f_skl
        for postfix in postfix_keys:
            in_dicts[postfix] = self.forward_stage(
                in_dicts[postfix],
                t,
                postfix,
                stage='out',
                t_emb=in_dicts[f't_emb{postfix}'],
                skips=in_dicts[f'skips{postfix}'],
                original_dtype=x.dtype,
            )
        if self.predict_x0:
            # velocity = (x0_pred - x_t) / t, with t clamped away from zero;
            # factor is gathered per-row via the batch index in coords[:, 0].
            t_normalized = t / self.t_scale
            factor = (1 / t_normalized.clamp_min(self.t_eps))[:, None]
            in_dicts['_vert_skin'] = in_dicts['_vert_skin'].replace((in_dicts['_vert_skin'].feats - xt_feats_skin) * factor[in_dicts['_vert_skin'].coords[:, 0]])
            in_dicts['_skl'] = in_dicts['_skl'].replace((in_dicts['_skl'].feats - xt_feats_skl) * factor[in_dicts['_skl'].coords[:, 0]])
        x_out = x.replace(torch.cat([in_dicts[''].feats, in_dicts['_vert_skin'].feats], dim=1))
        x_skl_out = x_skl.replace(in_dicts['_skl'].feats)
        return x_out, x_skl_out
547
+
548
class AniGenElasticSLatFlowModel(SparseTransformerElasticMixin, AniGenSLatFlowModel):
    """
    SLat Flow Model with elastic memory management.

    Identical to AniGenSLatFlowModel; the mixin toggles gradient checkpointing
    on a prefix of the transformer blocks to trade compute for activation
    memory. Used for training with low VRAM.
    """
    pass
anigen/models/sparse_elastic_mixin.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ from typing import *
3
+ import math
4
+ from ..modules import sparse as sp
5
+ from ..utils.elastic_utils import ElasticModuleMixin
6
+
7
+
8
class SparseTransformerElasticMixin(ElasticModuleMixin):
    """Elastic-memory mixin for sparse transformer models.

    Reduces activation memory by enabling gradient checkpointing on a prefix
    of ``self.blocks`` sized according to the requested memory ratio.
    """

    def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
        """Proxy for input size: the number of active rows in the sparse tensor."""
        return x.feats.shape[0]

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        """Temporarily checkpoint enough blocks to approximate ``mem_ratio``.

        Yields the exact ratio achieved (checkpointing is per-block, so the
        requested value is rounded). Flags are restored when the context exits.
        """
        if mem_ratio == 1.0:
            yield 1.0
            return
        num_blocks = len(self.blocks)
        num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
        exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
        for i in range(num_blocks):
            self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
        # Bug fix: restore the checkpoint flags even if the with-body raises;
        # previously an exception left checkpointing permanently enabled.
        try:
            yield exact_mem_ratio
        finally:
            for i in range(num_blocks):
                self.blocks[i].use_checkpoint = False
anigen/models/structured_latent_vae/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .anigen_encoder import AniGenElasticSLatEncoder
2
+ from .anigen_decoder import AniGenElasticSLatMeshDecoder
3
+ from .skin_models import SkinAutoEncoder
anigen/models/structured_latent_vae/anigen_base.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ...modules import sparse as sp
5
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
6
+ from ...modules.sparse.transformer import SparseTransformerMultiContextCrossBlock, SparseTransformerBlock
7
+ from ...modules.transformer import AbsolutePositionEmbedder, TransformerCrossBlock
8
+
9
+
10
class FreqPositionalEmbedder(nn.Module):
    """NeRF-style frequency positional encoding.

    Maps each input channel x to ``[x (optional), sin(f0*x), cos(f0*x),
    sin(f1*x), cos(f1*x), ...]`` along the last dimension, with frequencies up
    to ``2 ** max_freq_log2`` sampled log- or linearly-spaced.
    """

    def __init__(self, in_dim, include_input=True, max_freq_log2=8, num_freqs=8, log_sampling=True, periodic_fns=None):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = None
        self.include_input = include_input
        self.max_freq_log2 = max_freq_log2
        self.num_freqs = num_freqs
        self.log_sampling = log_sampling
        # Default periodic basis is (sin, cos).
        self.periodic_fns = [torch.sin, torch.cos] if periodic_fns is None else periodic_fns
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build the list of per-frequency callables and compute out_dim."""
        dim = self.in_dim
        fns = []
        total = 0

        if self.include_input:
            fns.append(lambda x: x)
            total += dim

        if self.log_sampling:
            bands = 2. ** torch.linspace(0., self.max_freq_log2, steps=self.num_freqs)
        else:
            bands = torch.linspace(2. ** 0., 2. ** self.max_freq_log2, steps=self.num_freqs)

        for freq in bands:
            for p_fn in self.periodic_fns:
                # Bind loop variables as defaults to avoid late-binding closures.
                fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                total += dim

        self.embed_fns = fns
        self.out_dim = total

    def forward(self, inputs):
        """Apply every embedding function and concatenate on the last dim."""
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
50
+
51
+
52
def block_attn_config(self, attn_mode_attr='attn_mode'):
    """
    Yield one attention configuration per transformer block.

    Each yielded tuple is (attn_mode, window_size, shift_sequence,
    shift_window, serialize_mode); alternation between consecutive blocks
    depends on the configured mode read from ``attn_mode_attr``.
    """
    mode = getattr(self, attn_mode_attr)
    for idx in range(self.num_blocks):
        odd = idx % 2
        if mode == "full":
            yield "full", None, None, None, None
        elif mode == "shift_window":
            # Shift the window by 16 voxels on every other block.
            yield "serialized", self.window_size, 0, (16 * odd,) * 3, sp.SerializeMode.Z_ORDER
        elif mode == "shift_sequence":
            # Shift by half a window along the serialized sequence on odd blocks.
            yield "serialized", self.window_size, self.window_size // 2 * odd, (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif mode == "shift_order":
            # Cycle through the four serialization orders.
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[idx % 4]
        elif mode == "swin":
            yield "windowed", self.window_size, None, self.window_size // 2 * odd, None
68
+
69
+
70
class AniGenSparseTransformerBase(nn.Module):
    """
    Sparse Transformer without output layers.
    Serve as the base class for encoder and decoder.

    Runs three branches in lock-step: geometry (self-attention only),
    skeleton ('_skl') and skinning ('_skin'), the latter two cross-attending
    to configurable context sets (each other and/or geometry).
    """
    def __init__(
        self,
        in_channels: int,
        in_channels_skl: int,
        in_channels_skin: int,
        model_channels: int,
        model_channels_skl: int,
        model_channels_skin: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_heads_skl: int = 8,
        num_heads_skin: int = 8,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
        window_size: Optional[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,

        # Toggles for which cross-attention contexts each branch receives.
        skin_cross_from_geo: bool = True,
        skl_cross_from_geo: bool = True,
        skin_skl_cross: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.in_channels_skl = in_channels_skl
        self.in_channels_skin = in_channels_skin
        self.model_channels = model_channels
        self.model_channels_skl = model_channels_skl
        self.model_channels_skin = model_channels_skin
        self.num_blocks = num_blocks
        self.window_size = window_size
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.attn_mode = attn_mode
        self.attn_mode_cross = attn_mode_cross
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.qk_rms_norm = qk_rms_norm
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.skin_cross_from_geo = skin_cross_from_geo
        self.skl_cross_from_geo = skl_cross_from_geo
        self.skin_skl_cross = skin_skl_cross

        if pe_mode == "ape":
            # Absolute positional embeddings on voxel coordinates, one per branch.
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)
            self.pos_embedder_skl = AbsolutePositionEmbedder(model_channels_skl)
            self.pos_embedder_skin = AbsolutePositionEmbedder(model_channels_skin)

        self.input_layer = sp.SparseLinear(in_channels, model_channels)
        self.input_layer_skl = sp.SparseLinear(in_channels_skl, model_channels_skl)
        self.input_layer_skin = sp.SparseLinear(in_channels_skin, model_channels_skin)

        # Geometry branch: plain self-attention blocks, scheduled per-block by
        # block_attn_config(self) (the comprehension's loop variables supply the
        # per-block attention parameters).
        self.blocks = nn.ModuleList([
            SparseTransformerBlock(
                model_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode=attn_mode,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
            )
            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
        ])

        # Skinning branch: cross-attends to skeleton and/or geometry features.
        # NOTE(review): the per-block schedule comes from `attn_mode_cross`, but
        # both `attn_mode` and `attn_mode_cross` below receive the same loop
        # value — confirm this is intended.
        ctx_channels = []
        if skin_skl_cross:
            ctx_channels.append(model_channels_skl)
        if skin_cross_from_geo:
            ctx_channels.append(model_channels)
        self.blocks_skin = nn.ModuleList([
            SparseTransformerMultiContextCrossBlock(
                model_channels_skin,
                ctx_channels=ctx_channels,
                num_heads=num_heads_skin,
                mlp_ratio=self.mlp_ratio,
                attn_mode=attn_mode,
                attn_mode_cross=attn_mode,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
                cross_attn_cache_suffix='_skin',
            )
            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self, "attn_mode_cross")
        ])

        # Skeleton branch: cross-attends to skinning and/or geometry features.
        ctx_channels = []
        if skin_skl_cross:
            ctx_channels.append(model_channels_skin)
        if skl_cross_from_geo:
            ctx_channels.append(model_channels)
        self.blocks_skl = nn.ModuleList([
            SparseTransformerMultiContextCrossBlock(
                model_channels_skl,
                ctx_channels=ctx_channels,
                num_heads=num_heads_skl,
                mlp_ratio=self.mlp_ratio,
                attn_mode=attn_mode,
                attn_mode_cross=attn_mode,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
                cross_attn_cache_suffix='_skl',
            )
            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self, "attn_mode_cross")
        ])

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16 (input layers stay fp32).
        """
        self.blocks.apply(convert_module_to_f16)
        self.blocks_skl.apply(convert_module_to_f16)
        self.blocks_skin.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)
        self.blocks_skl.apply(convert_module_to_f32)
        self.blocks_skin.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-initialize all Linear layers (zero biases)."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward_input_layer(self, x: sp.SparseTensor, layer, pos_embedder) -> sp.SparseTensor:
        """Project input features and add absolute position embeddings (if APE)."""
        h = layer(x)
        if self.pe_mode == "ape":
            # coords[:, 0] is the batch index; positions are coords[:, 1:].
            h = h + pos_embedder(x.coords[:, 1:])
        h = h.type(self.dtype)
        return h

    def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, x_skin: sp.SparseTensor) -> sp.SparseTensor:
        """Run the three branches through the torso.

        Returns:
            (h, h_skl, h_skin): per-branch feature tensors after all blocks.
        """
        h = self.forward_input_layer(x, self.input_layer, self.pos_embedder)
        h_skl = self.forward_input_layer(x_skl, self.input_layer_skl, self.pos_embedder_skl)
        h_skin = self.forward_input_layer(x_skin, self.input_layer_skin, self.pos_embedder_skin)

        for block, block_skl, block_skin in zip(self.blocks, self.blocks_skl, self.blocks_skin):
            # Snapshot the previous layer's outputs so every branch in this
            # layer cross-attends to the same (pre-update) features.
            f, f_skl, f_skin = h, h_skl, h_skin
            h = block(f)
            skl_contexts, skin_contexts = [], []
            if self.skin_skl_cross:
                skl_contexts.append(f_skin)
                skin_contexts.append(f_skl)
            if self.skl_cross_from_geo:
                skl_contexts.append(f)
            if self.skin_cross_from_geo:
                skin_contexts.append(f)
            h_skl = block_skl(f_skl, skl_contexts)
            h_skin = block_skin(f_skin, skin_contexts)
        return h, h_skl, h_skin
anigen/models/structured_latent_vae/anigen_decoder.py ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import contextlib
3
+ import torch
4
+ import torch.nn as nn
5
+ from ...modules.sparse.transformer import SparseTransformerMultiContextCrossBlock, SparseTransformerBlock
6
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
7
+ from ...modules import sparse as sp
8
+ from ...representations import MeshExtractResult
9
+ from ...representations.mesh import AniGenSparseFeatures2Mesh, AniGenSklFeatures2Skeleton
10
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
11
+ from pytorch3d.ops import knn_points
12
+ from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
13
+ from .skin_models import SKIN_MODEL_DICT
14
+ import torch.nn.functional as F
15
+ from ...representations.skeleton.grouping import GROUPING_STRATEGIES
16
+
17
+
18
class SparseSubdivideBlock3d(nn.Module):
    """
    A 3D subdivide block that can subdivide the sparse tensor.

    Args:
        channels: channels in the inputs and outputs.
        resolution: input grid resolution; only used (via ``out_resolution``)
            to name the spconv ``indice_key`` namespaces, so it must be unique
            per hierarchy level to avoid rulebook collisions.
        out_channels: if specified, the number of output channels.
        num_groups: the number of groups for the group norm.
        sub_divide: if True, double the spatial resolution via ``SparseSubdivide``;
            otherwise the block is a plain (non-subdividing) residual block.
        conv_as_residual: if True, use a 1x1 conv on the skip path even when the
            input/output channel counts match (instead of an identity skip).
    """
    def __init__(
        self,
        channels: int,
        resolution: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32,
        sub_divide: bool = True,
        conv_as_residual: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.resolution = resolution
        self.out_resolution = resolution * 2 if sub_divide else resolution
        self.out_channels = out_channels or channels
        self.sub_divide = sub_divide
        self.conv_as_residual = conv_as_residual

        # Pre-activation (GroupNorm + SiLU) applied before subdivision.
        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU()
        )

        self.sub = sp.SparseSubdivide() if sub_divide else nn.Identity()

        # Main path; the final conv is zero-initialized so the block starts
        # out contributing (approximately) only the skip connection.
        self.out_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
            sp.SparseGroupNorm32(num_groups, self.out_channels),
            sp.SparseSiLU(),
            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
        )

        if self.out_channels == channels and not self.conv_as_residual:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Apply the block to a sparse tensor.

        Args:
            x: an [N x C x ...] sparse tensor of features.
        Returns:
            an [N x C' x ...] sparse tensor of outputs, subdivided (doubled
            resolution) when ``sub_divide`` is True.
        """
        h = self.act_layers(x)
        h = self.sub(h)
        # Subdivide the skip input too so the residual add is coordinate-aligned.
        x = self.sub(x)
        h = self.out_layers(h)
        h = h + self.skip_connection(x)
        return h
78
+
79
+
80
class SparseDownsampleWithCache(nn.Module):
    """SparseDownsample that stores upsample caches under a unique suffix.

    ``sp.SparseDownsample`` registers its inverse-mapping caches under generic
    ``upsample_{factor}_*`` keys; stacking several down/up stages would make
    later stages overwrite earlier caches. This wrapper re-registers the caches
    under ``upsample_{factor}_{cache_suffix}_*`` and drops the generic keys.

    Args:
        factor: downsample factor, scalar or one entry per spatial dimension.
        cache_suffix: unique stage identifier; the paired
            `SparseUpsampleWithCache` must be constructed with the same suffix.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], cache_suffix: str):
        super().__init__()
        # Normalize list factors to tuples so cache-key strings are stable.
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
        self.cache_suffix = cache_suffix
        self._down = sp.SparseDownsample(self.factor)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Downsample ``x`` and stash the inverse caches under this stage's suffix."""
        out = self._down(x)

        # Number of spatial dims (coords are [batch, x, y, z]).
        dim = out.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * dim
        k_coords = f'upsample_{factor}_coords'
        k_layout = f'upsample_{factor}_layout'
        k_idx = f'upsample_{factor}_idx'
        coords = out.get_spatial_cache(k_coords)
        layout = out.get_spatial_cache(k_layout)
        idx = out.get_spatial_cache(k_idx)
        if any(v is None for v in [coords, layout, idx]):
            raise ValueError('Downsample cache not found after SparseDownsample.')

        # spconv expects int32 indices; SparseDownsample produces int64 coords.
        if out.coords.dtype != torch.int32:
            out = sp.SparseTensor(
                out.feats,
                out.coords.to(torch.int32),
                out.shape,
                out.layout,
                scale=out._scale,
                spatial_cache=out._spatial_cache,
            )

        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_coords', coords)
        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_layout', layout)
        out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_idx', idx)
        # Remove unsuffixed keys to prevent later stages overwriting them.
        # Best-effort removal: `dict.pop(key, None)` replaces the previous
        # broad `try/except Exception: pass` around `del`, which could also
        # have hidden unrelated errors.
        for key in (k_coords, k_layout, k_idx):
            out._spatial_cache.pop(key, None)
        return out
127
+
128
+
129
class SparseUpsampleWithCache(nn.Module):
    """SparseUpsample that reads upsample caches under a unique suffix.

    Must be paired with a `SparseDownsampleWithCache` that was constructed with
    the same ``factor`` and ``cache_suffix``: the pre-downsample coordinates,
    layout, and gather indices are restored from the suffixed spatial cache.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], cache_suffix: str):
        super().__init__()
        # Normalize list factors to tuples so cache-key strings are stable.
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
        self.cache_suffix = cache_suffix

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Restore the pre-downsample coordinate set, replicating coarse features."""
        # Number of spatial dims (coords are [batch, x, y, z]).
        dim = x.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * dim
        new_coords = x.get_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_coords')
        new_layout = x.get_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_layout')
        idx = x.get_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_idx')
        if any(v is None for v in [new_coords, new_layout, idx]):
            raise ValueError('Upsample cache not found. Must be paired with SparseDownsampleWithCache.')
        if new_coords.dtype != torch.int32:
            # spconv expects int32 coordinates.
            new_coords = new_coords.to(torch.int32)
        # Each fine voxel receives the features of its coarse parent.
        new_feats = x.feats[idx]
        out = sp.SparseTensor(new_feats, new_coords, x.shape, new_layout)
        out._scale = tuple([s * f for s, f in zip(x._scale, factor)])
        # Propagate remaining caches so stacked stages can still resolve theirs.
        out._spatial_cache = x._spatial_cache
        return out
151
+
152
+
153
class SparseSkinUNetNLevel(nn.Module):
    """A simple N-down/N-up sparse UNet for local smoothing.

    Note: `SparseSubdivideBlock3d` uses `resolution` only to name spconv `indice_key`s.
    We must provide distinct (and stage-appropriate) values per hierarchy to avoid
    rulebook collisions across different coordinate sets.

    Args:
        channels: feature channels, kept constant throughout the UNet.
        base_resolution: grid resolution of the input sparse tensor.
        num_groups: group count for the group norms inside the residual blocks.
        num_levels: number of down/up stages; must be >= 1.
    """
    def __init__(self, channels: int, base_resolution: int, num_groups: int = 32, num_levels: int = 3):
        super().__init__()

        if num_levels < 1:
            raise ValueError(f"num_levels must be >= 1, got {num_levels}")
        self.channels = channels
        self.base_resolution = int(base_resolution)
        self.num_groups = num_groups
        self.num_levels = int(num_levels)

        def res_block(resolution: int):
            # Non-subdividing residual block; `resolution` only selects the
            # spconv indice_key namespace for this hierarchy level.
            return SparseSubdivideBlock3d(
                channels=channels,
                resolution=resolution,
                out_channels=channels,
                sub_divide=False,
                conv_as_residual=True,
                num_groups=num_groups,
            )

        # resolutions[i] corresponds to the i-th encoder stage (before downsample)
        resolutions: List[int] = [max(1, self.base_resolution // (2 ** i)) for i in range(self.num_levels)]
        bottom_resolution = max(1, self.base_resolution // (2 ** self.num_levels))

        self.enc = nn.ModuleList([res_block(r) for r in resolutions])
        self.down = nn.ModuleList([SparseDownsampleWithCache(2, f'unet{i}') for i in range(self.num_levels)])
        self.mid = res_block(bottom_resolution)

        # Decoder blocks operate at the same resolutions as encoder blocks.
        self.up = nn.ModuleList([SparseUpsampleWithCache(2, f'unet{i}') for i in range(self.num_levels)])
        self.fuse = nn.ModuleList([sp.SparseLinear(channels * 2, channels) for _ in range(self.num_levels)])
        self.dec = nn.ModuleList([res_block(r) for r in resolutions])

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Run the UNet; the output keeps the input's coordinates and dtype."""
        in_dtype = x.feats.dtype
        if x.coords.dtype != torch.int32:
            # spconv expects int32 coordinates.
            x = sp.SparseTensor(
                x.feats,
                x.coords.to(torch.int32),
                x.shape,
                x.layout,
                scale=x._scale,
                spatial_cache=x._spatial_cache,
            )

        # spconv implicit_gemm has a runtime tuner that can fail for some sparse
        # rulebooks under AMP + fp16/bf16. Running UNet convs in fp32 avoids that.
        if hasattr(torch, 'autocast'):
            autocast_ctx = torch.autocast(device_type=x.device.type, enabled=False)
        else:
            # Older torch fallback
            autocast_ctx = torch.cuda.amp.autocast(enabled=False) if x.device.type == 'cuda' else contextlib.nullcontext()

        with autocast_ctx:
            x_fp32 = x if x.feats.dtype == torch.float32 else x.replace(x.feats.float())

            # Encoder: keep each stage's output for the skip connections.
            skips: List[sp.SparseTensor] = []
            h = x_fp32
            for i in range(self.num_levels):
                s = self.enc[i](h)
                skips.append(s)
                h = self.down[i](s)

            h = self.mid(h)

            # Decoder: upsample, concat skip features, project back to `channels`.
            for i in reversed(range(self.num_levels)):
                h_up = self.up[i](h)
                s = skips[i]
                h = self.fuse[i](h_up.replace(torch.cat([h_up.feats, s.feats], dim=-1)))
                h = self.dec[i](h)

            u0 = h

        # Cast back to the caller's dtype (the UNet ran internally in fp32).
        if in_dtype != u0.feats.dtype:
            u0 = u0.replace(u0.feats.to(dtype=in_dtype))
        return u0
236
+
237
+
238
+ class AniGenSLatMeshDecoder(AniGenSparseTransformerBase):
239
+ def __init__(
240
+ self,
241
+ resolution: int,
242
+ model_channels: int,
243
+ model_channels_skl: int,
244
+ model_channels_skin: int,
245
+
246
+ latent_channels: int,
247
+ latent_channels_skl: int,
248
+ latent_channels_vertskin: int,
249
+
250
+ num_blocks: int,
251
+ num_heads: Optional[int] = None,
252
+ num_head_channels: Optional[int] = 64,
253
+
254
+ num_heads_skl: int = 32,
255
+ num_heads_skin: int = 32,
256
+
257
+ skin_cross_from_groupped: bool = False,
258
+ h_skin_unet_num_levels: int = 4,
259
+
260
+ skin_decoder_config: Optional[Dict[str, Any]] = {},
261
+
262
+ upsample_skl: bool = False,
263
+ skl_defined_on_center: bool = True,
264
+ mlp_ratio: float = 4,
265
+ attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
266
+ attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
267
+ window_size: int = 8,
268
+ pe_mode: Literal["ape", "rope"] = "ape",
269
+ use_fp16: bool = False,
270
+ use_checkpoint: bool = False,
271
+ qk_rms_norm: bool = False,
272
+ representation_config: dict = None,
273
+
274
+ use_pretrain_branch: bool = True,
275
+ freeze_pretrain_branch: bool = True,
276
+ modules_to_freeze: Optional[List[str]] = ["blocks", "upsample", "out_layer", "skin_decoder"],
277
+
278
+ skin_cross_from_geo: bool = False,
279
+ skl_cross_from_geo: bool = False,
280
+ skin_skl_cross: bool = False,
281
+ skin_ae_name: str = "SkinAE",
282
+
283
+ normalize_z: bool = False,
284
+ normalize_scale: float = 1.0,
285
+
286
+ jp_residual_fields: bool = True,
287
+ jp_hyper_continuous: bool = True,
288
+
289
+ grouping_strategy: Literal["mean_shift", "threshold"] = "mean_shift",
290
+
291
+ vertex_skin_feat_interp_sparse: bool = False,
292
+ vertex_skin_feat_interp_nearest: bool = False,
293
+ vertex_skin_feat_interp_use_deformed_grid: bool = False,
294
+ vertex_skin_feat_interp_trilinear: bool = False,
295
+ flexicube_disable_deform: bool = False,
296
+ vertex_skin_feat_nodeform_trilinear: bool = False,
297
+ ):
298
+ super().__init__(
299
+ in_channels=latent_channels,
300
+ in_channels_skl=latent_channels_skl,
301
+ in_channels_skin=latent_channels_vertskin,
302
+ model_channels=model_channels,
303
+ model_channels_skl=model_channels_skl,
304
+ model_channels_skin=model_channels_skin,
305
+ num_blocks=num_blocks,
306
+ num_heads=num_heads,
307
+ num_heads_skl=num_heads_skl,
308
+ num_heads_skin=num_heads_skin,
309
+ num_head_channels=num_head_channels,
310
+ mlp_ratio=mlp_ratio,
311
+ attn_mode=attn_mode,
312
+ attn_mode_cross=attn_mode_cross,
313
+ window_size=window_size,
314
+ pe_mode=pe_mode,
315
+ use_fp16=use_fp16,
316
+ use_checkpoint=use_checkpoint,
317
+ qk_rms_norm=qk_rms_norm,
318
+ skin_cross_from_geo=skin_cross_from_geo,
319
+ skl_cross_from_geo=skl_cross_from_geo,
320
+ skin_skl_cross=skin_skl_cross,
321
+ )
322
+ self.pretrain_class_name = ["AniGenElasticSLatMeshDecoder", skin_ae_name]
323
+ self.pretrain_ckpt_filter_prefix = {skin_ae_name: "skin_decoder"}
324
+ self.latent_channels = latent_channels
325
+ self.latent_channels_skl = latent_channels_skl
326
+ self.latent_channels_vertskin = latent_channels_vertskin
327
+ self.jp_residual_fields = jp_residual_fields
328
+ self.jp_hyper_continuous = jp_hyper_continuous
329
+ self.grouping_func = GROUPING_STRATEGIES[grouping_strategy]
330
+ self.skin_cross_from_groupped = skin_cross_from_groupped
331
+
332
+ self.normalize_z = normalize_z
333
+ self.normalize_scale = normalize_scale
334
+
335
+ skin_decoder_config['use_fp16'] = use_fp16
336
+ self.skin_decoder = SKIN_MODEL_DICT[skin_decoder_config.pop('model_type')](**skin_decoder_config)
337
+ self.skin_feat_channels = self.skin_decoder.skin_feat_channels
338
+
339
+ # Optional local smoothing UNet on h_skin (independent of grouped cross-attn).
340
+ # If `h_skin_unet_num_levels < 0`, UNet is disabled.
341
+ self.h_skin_unet_num_levels = int(h_skin_unet_num_levels)
342
+ if self.h_skin_unet_num_levels >= 1:
343
+ self.h_skin_unet = SparseSkinUNetNLevel(
344
+ model_channels_skin,
345
+ base_resolution=resolution,
346
+ num_levels=self.h_skin_unet_num_levels,
347
+ )
348
+ else:
349
+ self.h_skin_unet = None
350
+
351
+ if self.skin_cross_from_groupped:
352
+ # Trainable parent feature for root joints (where parent_idx < 0).
353
+ self.root_parent_feat = nn.Parameter(torch.zeros(self.skin_feat_channels))
354
+
355
+ # Joint feature preprocessing: [joint_skin, fourier(joint_xyz), parent_skin] -> proj -> self-attn
356
+ self.joints_pos_embedder = FreqPositionalEmbedder(
357
+ in_dim=3,
358
+ include_input=True,
359
+ max_freq_log2=6,
360
+ num_freqs=6,
361
+ log_sampling=True,
362
+ )
363
+ joints_pe_dim = self.joints_pos_embedder.out_dim
364
+ joints_in_dim = self.skin_feat_channels + joints_pe_dim + self.skin_feat_channels
365
+ self.joints_ctx_channels = model_channels_skin
366
+ self.joints_in_proj = nn.Sequential(
367
+ nn.Linear(joints_in_dim, self.joints_ctx_channels, bias=True),
368
+ nn.SiLU(),
369
+ nn.LayerNorm(self.joints_ctx_channels, elementwise_affine=True),
370
+ )
371
+ self.joints_self_attn = nn.ModuleList([
372
+ SparseTransformerBlock(
373
+ self.joints_ctx_channels,
374
+ num_heads=num_heads_skin,
375
+ mlp_ratio=self.mlp_ratio,
376
+ attn_mode="full",
377
+ window_size=None,
378
+ use_checkpoint=self.use_checkpoint,
379
+ use_rope=False,
380
+ qk_rms_norm=self.qk_rms_norm,
381
+ ln_affine=True,
382
+ ) for _ in range(4)
383
+ ])
384
+
385
+ # Coordinate PE for h_skin before cross-attn: coords in [-1, 1] -> Fourier PE -> proj(C), concat, fuse back to C.
386
+ self.h_skin_coord_embedder = FreqPositionalEmbedder(
387
+ in_dim=3,
388
+ include_input=True,
389
+ max_freq_log2=6,
390
+ num_freqs=6,
391
+ log_sampling=True,
392
+ )
393
+ h_skin_pe_dim = self.h_skin_coord_embedder.out_dim
394
+ self.h_skin_coord_proj = nn.Linear(h_skin_pe_dim, model_channels_skin, bias=True)
395
+ self.h_skin_coord_fuse = sp.SparseLinear(model_channels_skin * 2, model_channels_skin)
396
+
397
+ self.skin_cross_groupped_net = SparseTransformerMultiContextCrossBlock(
398
+ model_channels_skin,
399
+ # Context includes processed joint tokens + raw joint skin feats (skip connection).
400
+ ctx_channels=[self.joints_ctx_channels + self.skin_feat_channels],
401
+ num_heads=num_heads_skin,
402
+ mlp_ratio=self.mlp_ratio,
403
+ attn_mode="full",
404
+ attn_mode_cross="full",
405
+ cross_attn_cache_suffix='_skin_cross_from_groupped',
406
+ )
407
+
408
+ self.resolution = resolution
409
+ self.use_pretrain_branch = use_pretrain_branch
410
+ self.freeze_pretrain_branch = freeze_pretrain_branch
411
+ self.upsample_skl = upsample_skl
412
+ self.rep_config = representation_config
413
+ self.mesh_extractor = AniGenSparseFeatures2Mesh(
414
+ res=self.resolution*4,
415
+ use_color=self.rep_config.get('use_color', False),
416
+ skin_feat_channels=self.skin_feat_channels,
417
+ predict_skin=True,
418
+ vertex_skin_feat_interp_sparse=vertex_skin_feat_interp_sparse,
419
+ vertex_skin_feat_interp_nearest=vertex_skin_feat_interp_nearest,
420
+ vertex_skin_feat_interp_use_deformed_grid=vertex_skin_feat_interp_use_deformed_grid,
421
+ vertex_skin_feat_interp_trilinear=vertex_skin_feat_interp_trilinear,
422
+ flexicube_disable_deform=flexicube_disable_deform,
423
+ vertex_skin_feat_nodeform_trilinear=vertex_skin_feat_nodeform_trilinear,
424
+ )
425
+ self.out_channels = self.mesh_extractor.feats_channels
426
+ self.upsample = nn.ModuleList([
427
+ SparseSubdivideBlock3d(
428
+ channels=model_channels,
429
+ resolution=resolution,
430
+ out_channels=model_channels // 4
431
+ ),
432
+ SparseSubdivideBlock3d(
433
+ channels=model_channels // 4,
434
+ resolution=resolution * 2,
435
+ out_channels=model_channels // 8
436
+ )
437
+ ])
438
+ upsample_skin_blocks = []
439
+ upsample_skin_blocks.extend([
440
+ SparseSubdivideBlock3d(
441
+ channels=model_channels_skin,
442
+ resolution=resolution,
443
+ out_channels=model_channels // 4
444
+ ),
445
+ SparseSubdivideBlock3d(
446
+ channels=model_channels // 4,
447
+ resolution=resolution * 2,
448
+ out_channels=model_channels // 8
449
+ )
450
+ ])
451
+
452
+ self.upsample_skin_net = nn.ModuleList(upsample_skin_blocks)
453
+ self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
454
+ self.out_layer_skin = sp.SparseLinear(model_channels // 8, self.skin_feat_channels*8)
455
+ self.out_layer_skl_skin = sp.SparseLinear(model_channels // 8 if upsample_skl else model_channels_skl, self.skin_feat_channels if skl_defined_on_center else self.skin_feat_channels * 8)
456
+ self.use_conf_jp = self.rep_config.get('use_conf_jp', False) or self.jp_hyper_continuous
457
+ self.use_conf_skin = self.rep_config.get('use_conf_skin', False)
458
+
459
+ res_skl = self.resolution * 4 if self.upsample_skl else self.resolution
460
+ self.skeleton_extractor = AniGenSklFeatures2Skeleton(skin_feat_channels=self.skin_feat_channels, device=self.device, res=res_skl, use_conf_jp=self.use_conf_jp, use_conf_skin=self.use_conf_skin, predict_skin=True, defined_on_center=skl_defined_on_center, jp_hyper_continuous=self.jp_hyper_continuous, jp_residual_fields=self.jp_residual_fields)
461
+
462
+ self.out_channels_skl = self.skeleton_extractor.feats_channels
463
+ if self.upsample_skl:
464
+ self.upsample_skl_net = nn.ModuleList([
465
+ SparseSubdivideBlock3d(
466
+ channels=model_channels_skl,
467
+ resolution=resolution,
468
+ out_channels=model_channels // 4
469
+ ),
470
+ SparseSubdivideBlock3d(
471
+ channels=model_channels // 4,
472
+ resolution=resolution * 2,
473
+ out_channels=model_channels // 8
474
+ )
475
+ ])
476
+ self.out_layer_skl = sp.SparseLinear(model_channels // 8, self.out_channels_skl)
477
+ else:
478
+ self.out_layer_skl = sp.SparseLinear(model_channels_skl, self.out_channels_skl)
479
+
480
+ self.initialize_weights()
481
+ if use_fp16:
482
+ self.convert_to_fp16()
483
+ else:
484
+ self.convert_to_fp32()
485
+
486
+ if self.use_pretrain_branch and self.freeze_pretrain_branch:
487
+ for module in modules_to_freeze:
488
+ if hasattr(self, module):
489
+ mod = getattr(self, module)
490
+ if isinstance(mod, nn.ModuleList):
491
+ for m in mod:
492
+ for name, param in m.named_parameters():
493
+ if 'lora' not in name:
494
+ param.requires_grad = False
495
+ elif isinstance(mod, nn.Module):
496
+ for name, param in mod.named_parameters():
497
+ if 'lora' not in name:
498
+ param.requires_grad = False
499
+ elif isinstance(mod, torch.Tensor):
500
+ if mod.requires_grad:
501
+ mod.requires_grad = False
502
+
503
+ def initialize_weights(self) -> None:
504
+ super().initialize_weights()
505
+ scale = 1e-4
506
+ # Kaiming initialization for output layers (better for ReLU/SiLU-like activations)
507
+ nn.init.kaiming_normal_(self.out_layer.weight, mode='fan_in', nonlinearity='relu')
508
+ self.out_layer.weight.data.mul_(scale)
509
+ nn.init.constant_(self.out_layer.bias, 0)
510
+
511
+ nn.init.kaiming_normal_(self.out_layer_skl.weight, mode='fan_in', nonlinearity='relu')
512
+ self.out_layer_skl.weight.data.mul_(scale)
513
+ nn.init.constant_(self.out_layer_skl.bias, 0)
514
+
515
+ # Initialize skin layer:
516
+ self.skin_decoder.initialize_weights()
517
+ nn.init.kaiming_normal_(self.out_layer_skin.weight, mode='fan_in', nonlinearity='relu')
518
+ self.out_layer_skin.weight.data.mul_(scale)
519
+ nn.init.constant_(self.out_layer_skin.bias, 0)
520
+
521
+ nn.init.kaiming_normal_(self.out_layer_skl_skin.weight, mode='fan_in', nonlinearity='relu')
522
+ self.out_layer_skl_skin.weight.data.mul_(scale)
523
+ nn.init.constant_(self.out_layer_skl_skin.bias, 0)
524
+
525
    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.

        Modules that must stay fp32 for numerical/runtime reasons (LayerNorm32
        parameters, the h_skin UNet) are explicitly kept or re-cast to fp32.
        """
        super().convert_to_fp16()
        self.upsample.apply(convert_module_to_f16)
        self.upsample_skin_net.apply(convert_module_to_f16)
        if self.upsample_skl:
            self.upsample_skl_net.apply(convert_module_to_f16)
        if self.skin_cross_from_groupped:
            # Joint preprocessing and cross-attn should match model dtype.
            self.root_parent_feat.data = self.root_parent_feat.data.half()
            self.joints_in_proj.apply(convert_module_to_f16)
            self.joints_self_attn.apply(convert_module_to_f16)

            # `convert_module_to_f16` doesn't include `nn.LayerNorm`, so cast LN params explicitly.
            for _m in self.joints_in_proj.modules():
                if isinstance(_m, nn.LayerNorm):
                    if _m.weight is not None:
                        _m.weight.data = _m.weight.data.half()
                    if _m.bias is not None:
                        _m.bias.data = _m.bias.data.half()

            # IMPORTANT: `SparseTransformerBlock` uses `LayerNorm32` which internally
            # normalizes in fp32 (`x.float()`), so its parameters must stay fp32.
            for _m in self.joints_self_attn.modules():
                if isinstance(_m, nn.LayerNorm):
                    if _m.weight is not None:
                        _m.weight.data = _m.weight.data.float()
                    if _m.bias is not None:
                        _m.bias.data = _m.bias.data.float()

            self.skin_cross_groupped_net.apply(convert_module_to_f16)
            self.h_skin_coord_proj.apply(convert_module_to_f16)
            self.h_skin_coord_fuse.apply(convert_module_to_f16)

        # UNet is executed in fp32 (see `SparseSkinUNetNLevel.forward`), so keep its
        # weights in fp32 to avoid dtype mismatches inside spconv.
        if self.h_skin_unet is not None:
            self.h_skin_unet.apply(convert_module_to_f32)
        self.skin_decoder.convert_to_fp16()
566
+
567
    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.

        Mirrors `convert_to_fp16`: the same module set is visited so a model
        can be toggled between precisions consistently.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)
        self.upsample_skin_net.apply(convert_module_to_f32)
        if self.upsample_skl:
            self.upsample_skl_net.apply(convert_module_to_f32)
        if self.skin_cross_from_groupped:
            self.root_parent_feat.data = self.root_parent_feat.data.float()
            self.joints_in_proj.apply(convert_module_to_f32)
            self.joints_self_attn.apply(convert_module_to_f32)

            # `convert_module_to_f32` skips `nn.LayerNorm`; cast LN params explicitly.
            for _m in self.joints_in_proj.modules():
                if isinstance(_m, nn.LayerNorm):
                    if _m.weight is not None:
                        _m.weight.data = _m.weight.data.float()
                    if _m.bias is not None:
                        _m.bias.data = _m.bias.data.float()
            for _m in self.joints_self_attn.modules():
                if isinstance(_m, nn.LayerNorm):
                    if _m.weight is not None:
                        _m.weight.data = _m.weight.data.float()
                    if _m.bias is not None:
                        _m.bias.data = _m.bias.data.float()

            self.skin_cross_groupped_net.apply(convert_module_to_f32)
            self.h_skin_coord_proj.apply(convert_module_to_f32)
            self.h_skin_coord_fuse.apply(convert_module_to_f32)
        # The h_skin UNet always runs in fp32 (see `SparseSkinUNetNLevel.forward`).
        if self.h_skin_unet is not None:
            self.h_skin_unet.apply(convert_module_to_f32)
        self.skin_decoder.convert_to_fp32()
600
+
601
+ def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
602
+ """
603
+ Convert a batch of network outputs to 3D representations.
604
+
605
+ Args:
606
+ x: The [N x * x C] sparse tensor output by the network.
607
+
608
+ Returns:
609
+ list of representations
610
+ """
611
+ ret = []
612
+ for i in range(x.shape[0]):
613
+ mesh = self.mesh_extractor(x[i], training=self.training)
614
+ ret.append(mesh)
615
+ return ret
616
+
617
+ def to_representation_skl(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
618
+ """
619
+ Convert a batch of network outputs to skeleton representations.
620
+
621
+ Args:
622
+ x: The [N x * x C] sparse tensor output by the network.
623
+
624
+ Returns:
625
+ list of skeleton representations
626
+ """
627
+ ret = []
628
+ for i in range(x.shape[0]):
629
+ skl = self.skeleton_extractor(x[i], training=self.training)
630
+ ret.append(skl)
631
+ return ret
632
+
633
    def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, gt_joints=None, gt_parents=None) -> Tuple[List[MeshExtractResult], List[Any]]:
        """
        Decode structured latents into per-batch meshes and skeletons.

        Args:
            x: sparse latent whose feature dim packs the geometry latent
                (first `latent_channels`) followed by the vertex-skin latent.
            x_skl: sparse skeleton latent.
            gt_joints: optional per-batch grouped joints to use instead of the
                predicted grouping (e.g. teacher forcing during training).
            gt_parents: optional per-batch grouped parents, paired with
                `gt_joints`.

        Returns:
            (meshes, skeletons) lists; `skinweight_forward` additionally
            attaches `skin_pred` to each skeleton as a side effect.
        """
        x0 = x
        # Split the packed latent: first `latent_channels` -> geometry, rest -> skin.
        x_skin = sp.SparseTensor(feats=x0.feats[:, self.latent_channels:], coords=x0.coords.clone())
        x = x0.replace(x0.feats[:, :self.latent_channels])
        if self.normalize_z:
            x_skin = x_skin.replace(F.normalize(x_skin.feats, dim=-1))
            x_skl = x_skl.replace(F.normalize(x_skl.feats, dim=-1))

        # Backbone forward
        h, h_skl, h_skin = super().forward(x, x_skl, x_skin)

        # Optional smoothing on h_skin.
        if self.h_skin_unet is not None:
            h_skin = self.h_skin_unet(h_skin)

        # Skeleton prediction
        if self.upsample_skl:
            for block_skl in self.upsample_skl_net:
                h_skl = block_skl(h_skl)
        h_skl_middle = h_skl.type(x_skl.dtype)
        h_skl = self.out_layer_skl(h_skl_middle)
        h_skl_skin = self.out_layer_skl_skin(h_skl_middle)
        # Skeleton features and per-voxel skin features travel together.
        h_skl = h_skl.replace(torch.cat([h_skl.feats, h_skl_skin.feats], dim=-1))
        skeletons = self.to_representation_skl(h_skl)
        skin_feats_joints_list = self.skeleton_grouping(skeletons, gt_joints=gt_joints, gt_parents=gt_parents)

        # Skin cross with grouped joint features
        if self.skin_cross_from_groupped:
            # Voxel centers mapped to [-1, 1] for the Fourier positional embedding.
            coords_xyz = h_skin.coords[:, 1:].to(device=h_skin.device, dtype=torch.float32)
            coords_norm = (coords_xyz + 0.5) / self.resolution * 2.0 - 1.0
            coords_pe = self.h_skin_coord_embedder(coords_norm)
            coords_pe = coords_pe.to(device=h_skin.device, dtype=h_skin.feats.dtype)
            coords_pe = self.h_skin_coord_proj(coords_pe)
            h_skin = h_skin.replace(torch.cat([h_skin.feats, coords_pe], dim=-1))
            h_skin = self.h_skin_coord_fuse(h_skin)
            joints_ctx = self._build_processed_joints_context(
                skeletons,
                skin_feats_joints_list,
                device=h_skin.device,
                dtype=h_skin.feats.dtype,
            )
            h_skin = self.skin_cross_groupped_net(h_skin, [joints_ctx])
        for block in self.upsample_skin_net:
            h_skin = block(h_skin)
        h_skin = h_skin.type(x.dtype)
        h_skin = self.out_layer_skin(h_skin)

        # Mesh prediction
        for block in self.upsample:
            h = block(h)
        h_middle = h.type(x.dtype)
        h = self.out_layer(h_middle)
        # Append per-voxel skin features so mesh extraction can interpolate them.
        h = h.replace(torch.cat([h.feats, h_skin.feats], dim=-1))
        meshes = self.to_representation(h)

        # Side effect: attaches `skin_pred` to each skeleton representation.
        self.skinweight_forward(meshes, skeletons, gt_joints=gt_joints, gt_parents=gt_parents)
        return meshes, skeletons
690
+
691
+ def _joints_feats_list_to_sparse(
692
+ self,
693
+ joints_feats_list: List[torch.Tensor],
694
+ device: Optional[torch.device] = None,
695
+ dtype: Optional[torch.dtype] = None,
696
+ ) -> sp.SparseTensor:
697
+ if device is None:
698
+ device = self.device
699
+ if dtype is None:
700
+ dtype = self.dtype
701
+ feats_per_batch: List[torch.Tensor] = []
702
+ for joints_feats in joints_feats_list:
703
+ joints_feats = joints_feats.to(device=device, dtype=dtype)
704
+ feats_per_batch.append(joints_feats)
705
+ feats = torch.cat(feats_per_batch, dim=0)
706
+ # Coords are [batch, x, y, z]. We encode token index into x and keep y/z = 0.
707
+ batch_indices: List[torch.Tensor] = []
708
+ x_indices: List[torch.Tensor] = []
709
+ for bi, joints_feats in enumerate(feats_per_batch):
710
+ ji = int(joints_feats.shape[0])
711
+ batch_indices.append(torch.full((ji,), bi, device=device, dtype=torch.int32))
712
+ x_indices.append(torch.arange(ji, device=device, dtype=torch.int32))
713
+ b = torch.cat(batch_indices, dim=0)
714
+ x = torch.cat(x_indices, dim=0)
715
+ yz = torch.zeros((x.shape[0], 2), device=device, dtype=torch.int32)
716
+ coords = torch.cat([b[:, None], x[:, None], yz], dim=1)
717
+ return sp.SparseTensor(feats=feats, coords=coords)
718
+
719
    def _build_processed_joints_context(
        self,
        skeletons: List[Any],
        skin_feats_joints_list: List[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> sp.SparseTensor:
        """
        Build the joint-token context for grouped skin cross-attention.

        Per joint token: [skin_feats, fourier(joint_xyz), parent skin_feats]
        -> linear projection -> joint self-attention; the raw joint skin
        features are then re-concatenated as a skip connection.

        Raises:
            ValueError: if a skeleton lacks grouped joints/parents.
        """
        processed: List[torch.Tensor] = []
        raw_skin: List[torch.Tensor] = []
        for rep_skl, skin_feats_joints in zip(skeletons, skin_feats_joints_list):
            joints = rep_skl.joints_grouped
            parents = rep_skl.parents_grouped
            if joints is None or parents is None:
                raise ValueError('Expected grouped joints/parents for skin_cross_from_groupped.')
            joints = joints.to(device=device, dtype=dtype)
            parents = parents.to(device=device)
            skin_feats_joints = skin_feats_joints.to(device=device, dtype=dtype)
            raw_skin.append(skin_feats_joints)

            pe = self.joints_pos_embedder(joints).to(device=device, dtype=dtype)

            # Parent skin features (root uses trainable parameter)
            parent_idx = parents.to(torch.long)
            valid = parent_idx >= 0
            root_feat = self.root_parent_feat.to(device=device, dtype=dtype)
            parent_feat_root = root_feat.unsqueeze(0).expand(skin_feats_joints.shape[0], -1)
            # clamp(min=0) keeps gather indices legal; invalid rows are masked below.
            parent_feat_gather = skin_feats_joints[parent_idx.clamp(min=0)]
            parent_feat = torch.where(valid.unsqueeze(1), parent_feat_gather, parent_feat_root)

            joint_in = torch.cat([skin_feats_joints, pe, parent_feat], dim=-1)
            joint_h = self.joints_in_proj(joint_in)
            processed.append(joint_h)

        joints_ctx = self._joints_feats_list_to_sparse(processed, device=device, dtype=dtype)
        for blk in self.joints_self_attn:
            joints_ctx = blk(joints_ctx)
        # Skip connection: concatenate original joint skin feats after self-attn.
        joints_skip = self._joints_feats_list_to_sparse(raw_skin, device=device, dtype=dtype)
        joints_ctx = joints_ctx.replace(torch.cat([joints_ctx.feats, joints_skip.feats], dim=-1))
        return joints_ctx
759
+
760
    def skeleton_grouping(self, reps_skl, gt_joints=None, gt_parents=None, skin_feats_skl_list=None, return_skin_pred_only=False):
        """
        Group skeleton joints and aggregate per-joint skin features.

        For each skeleton: obtain grouped joints/parents (from ground truth,
        the cached grouping, or `self.grouping_func`), then average the
        per-voxel skin features onto their nearest grouped joint, weighted by
        the (sigmoid) skin confidence.

        Args:
            reps_skl: per-batch skeleton representations.
            gt_joints / gt_parents: optional ground-truth grouping to use
                instead of prediction.
            skin_feats_skl_list: optional override of each skeleton's
                `skin_feats`.
            return_skin_pred_only: if True, do not cache results on the
                representations (pure-function mode).

        Returns:
            List of per-batch [num_joints, C] averaged joint skin features.
        """
        skin_feats_joints_list = []
        for i, rep_skl in zip(range(len(reps_skl)), reps_skl):
            if gt_joints is not None:
                joints_grouped = gt_joints[i]
                parents_grouped = gt_parents[i]
            elif rep_skl.joints_grouped is None:
                # Grouping is non-differentiable; keep it out of the graph.
                with torch.no_grad():
                    joints_grouped, parents_grouped = self.grouping_func(joints=rep_skl.joints, parents=rep_skl.parents, joints_conf=rep_skl.conf_j, parents_conf=rep_skl.conf_p)
            else:
                joints_grouped = rep_skl.joints_grouped
                parents_grouped = rep_skl.parents_grouped

            if not return_skin_pred_only:
                rep_skl.joints_grouped = joints_grouped
                rep_skl.parents_grouped = parents_grouped

            # Calculate NN indices for joints
            positions_skl = rep_skl.positions
            _, joints_nn_idx, _ = knn_points(positions_skl[None], joints_grouped[None].detach(), K=1, norm=2, return_nn=False)
            joints_nn_idx = joints_nn_idx[0, :, 0]
            skin_feats_skl = rep_skl.skin_feats if skin_feats_skl_list is None else skin_feats_skl_list[i]

            # Average the predicted joint features
            conf_skin = torch.sigmoid(rep_skl.conf_skin) if rep_skl.conf_skin is not None else torch.ones_like(skin_feats_skl[:, :1])

            skin_feats_joints = torch.zeros([joints_grouped.shape[0], skin_feats_skl.shape[-1]], device=self.device, dtype=skin_feats_skl.dtype)
            skin_feats_square_joints = skin_feats_joints.clone()
            skin_conf_joints = torch.zeros([joints_grouped.shape[0], 1], device=self.device, dtype=skin_feats_skl.dtype)

            # Confidence-weighted sums of features and squared features per joint.
            skin_feats_joints.scatter_add_(0, joints_nn_idx[:, None].expand(-1, skin_feats_skl.shape[-1]), skin_feats_skl * conf_skin)
            skin_feats_square_joints.scatter_add_(0, joints_nn_idx[:, None].expand(-1, skin_feats_skl.shape[-1]), skin_feats_skl.square() * conf_skin)
            skin_conf_joints.scatter_add_(0, joints_nn_idx[:, None], conf_skin)

            # Weighted mean; E[x^2] - E[x]^2 gives a per-joint variance penalty.
            skin_feats_joints = skin_feats_joints / skin_conf_joints.clamp(min=1e-6)
            skin_feats_square_joints = skin_feats_square_joints / skin_conf_joints.clamp(min=1e-6)
            skin_feats_joints_var = skin_feats_square_joints - skin_feats_joints.square()
            skin_feats_joints_var_loss = skin_feats_joints_var.mean()

            if not return_skin_pred_only:
                rep_skl.skin_feats_joints_var_loss = skin_feats_joints_var_loss
                rep_skl.skin_feats_joints = skin_feats_joints
            skin_feats_joints_list.append(skin_feats_joints)
        return skin_feats_joints_list
804
+
805
    def skinweight_forward(self, reps, reps_skl, gt_joints=None, gt_parents=None, return_skin_pred_only=False, skin_feats_verts_list=None, skin_feats_skl_list=None, *args, **kwargs):
        """
        Predict per-vertex skin weights from vertex and joint skin features.

        Args:
            reps: per-batch mesh representations (providing `vertex_skin_feats`).
            reps_skl: per-batch skeleton representations.
            gt_joints / gt_parents: optional ground-truth grouping, forwarded
                to `skeleton_grouping`.
            return_skin_pred_only: if True, return the predictions instead of
                writing `skin_pred` onto `reps_skl`.
            skin_feats_verts_list / skin_feats_skl_list: optional overrides for
                the vertex/skeleton skin features.

        Returns:
            List of per-batch skin predictions when `return_skin_pred_only`,
            otherwise None (results are stored on `reps_skl`).
        """
        if return_skin_pred_only:
            skin_preds = []
        # Reuse cached joint features when grouping already ran this forward pass.
        if reps_skl[0].parents_grouped is None or return_skin_pred_only:
            skin_feats_joints_list = self.skeleton_grouping(reps_skl, gt_joints=gt_joints, gt_parents=gt_parents, skin_feats_skl_list=skin_feats_skl_list, return_skin_pred_only=return_skin_pred_only)
        else:
            skin_feats_joints_list = [rep_skl.skin_feats_joints for rep_skl in reps_skl]
        for i, rep, rep_skl in zip(range(len(reps)), reps, reps_skl):
            # Joint skinning features
            skin_feats_joints = skin_feats_joints_list[i]
            # Vertex skinning features
            skin_feats_verts = rep.vertex_skin_feats if skin_feats_verts_list is None else skin_feats_verts_list[i]
            # Predict skin weights (skin decoder works on batched [1, ...] inputs).
            parents_grouped = rep_skl.parents_grouped
            skin_pred = self.skin_decoder(skin_feats_verts[None], skin_feats_joints[None], parents_grouped[None])
            skin_pred = skin_pred[0]
            if return_skin_pred_only:
                skin_preds.append(skin_pred)
            else:
                reps_skl[i].skin_pred = skin_pred
        if return_skin_pred_only:
            return skin_preds
827
+
828
+
829
class AniGenElasticSLatMeshDecoder(SparseTransformerElasticMixin, AniGenSLatMeshDecoder):
    """
    Slat VAE Mesh decoder with elastic memory management.

    Behavior matches `AniGenSLatMeshDecoder`; the mixin only adds memory
    management. Used for training with low VRAM.
    """
    pass
anigen/models/structured_latent_vae/anigen_encoder.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from ...modules import sparse as sp
6
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
7
+ from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
8
+ from pytorch3d.ops import knn_points
9
+ from .skin_models import SkinEncoder
10
+
11
+
12
def block_attn_config(self):
    """
    Yield one attention configuration per transformer block, as the tuple
    (attn_mode, window_size, shift_sequence, shift_window, serialize_mode).
    Shifted variants alternate between even and odd block indices.
    """
    mode = self.attn_mode
    for idx in range(self.num_blocks):
        odd = idx % 2
        if mode == "shift_window":
            shift = 16 * odd
            yield "serialized", self.window_size, 0, (shift, shift, shift), sp.SerializeMode.Z_ORDER
        elif mode == "shift_sequence":
            yield "serialized", self.window_size, (self.window_size // 2) * odd, (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif mode == "shift_order":
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[idx % 4]
        elif mode == "full":
            yield "full", None, None, None, None
        elif mode == "swin":
            yield "windowed", self.window_size, None, (self.window_size // 2) * odd, None
27
+
28
+
29
class FeedForwardNet(nn.Module):
    """
    Two-layer MLP: Linear -> GELU(tanh approx) -> Linear, with hidden width
    ``int(channels * mlp_ratio)``. Output width defaults to the input width.
    """

    def __init__(self, channels: int, channels_out: int = None, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        out_dim = channels if channels_out is None else channels_out
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP over the last dimension of ``x``."""
        return self.mlp(x)
41
+
42
+
43
class AniGenSLatEncoder(AniGenSparseTransformerBase):
    """
    Structured-latent (SLat) VAE encoder for animatable 3D generation.

    Encodes three streams:
      * geometry voxel features -> VAE latent ``z`` (concatenated with a
        per-voxel skin latent),
      * skeleton features ``z_skl`` built from joint/parent positional
        embeddings plus joint skin embeddings,
      * vertex skin features produced by a (possibly pretrained) SkinEncoder.

    When ``latent_denoising`` is enabled, the skeleton / vertex-skin heads
    output deterministic latents (logvar fixed to zero) that may be
    L2-normalized and scaled by ``normalize_scale``.

    Fix vs. original: the mutable default arguments ``skin_encoder_config={}``
    and ``modules_to_freeze=[...]`` were shared across instances, and
    ``skin_encoder_config`` was mutated in place (``'use_fp16'`` written into
    the caller's dict / the shared default). They are now None-sentinels and
    the config dict is copied before mutation. Behavior for all existing
    callers is unchanged.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,

        model_channels: int,
        model_channels_skl: int,
        model_channels_skin: int,

        latent_channels: int,
        latent_channels_skl: int,
        latent_channels_vertskin: int,

        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,

        num_heads_skl: int = 32,
        num_heads_skin: int = 32,

        skl_pos_embed_freq: int = 10,
        skin_encoder_config: Optional[Dict[str, Any]] = None,
        encode_upsampled_skin_feat: bool = True,
        skin_ae_name: Optional[str] = "SkinAE",

        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,

        use_pretrain_branch: bool = True,
        freeze_pretrain_branch: bool = True,
        modules_to_freeze: Optional[List[str]] = None,

        skin_cross_from_geo: bool = True,
        skl_cross_from_geo: bool = True,
        skin_skl_cross: bool = True,

        latent_denoising: bool = True,
        normalize_z: bool = True,
        normalize_scale: float = 1.0,

        jp_residual_fields: bool = False,
        jp_hyper_continuous: bool = False,
    ):
        # Normalize the former mutable defaults; copy the caller's config so
        # the 'use_fp16' write below cannot leak back into it.
        skin_encoder_config = dict(skin_encoder_config) if skin_encoder_config else {}
        if modules_to_freeze is None:
            modules_to_freeze = ["input_layer", "blocks", "out_layer", "skin_encoder"]

        # Plain-value attributes set before nn.Module.__init__ runs (inside
        # super().__init__); safe only because none of these are Parameters.
        self.use_pretrain_branch = use_pretrain_branch
        self.freeze_pretrain_branch = freeze_pretrain_branch
        self.skl_pos_embed_freq = skl_pos_embed_freq
        self.latent_denoising = latent_denoising
        # Normalization only applies to the denoising (deterministic) latents.
        self.normalize_latent = normalize_z and latent_denoising
        self.normalize_scale = normalize_scale
        self.jp_residual_fields = jp_residual_fields
        self.jp_hyper_continuous = jp_hyper_continuous

        super().__init__(
            in_channels=in_channels,
            in_channels_skl=model_channels_skl,
            in_channels_skin=model_channels_skin,
            model_channels=model_channels,
            model_channels_skl=model_channels_skl,
            model_channels_skin=model_channels_skin,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_heads_skl=num_heads_skl,
            num_heads_skin=num_heads_skin,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            attn_mode_cross=attn_mode_cross,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
            skin_cross_from_geo=skin_cross_from_geo,
            skl_cross_from_geo=skl_cross_from_geo,
            skin_skl_cross=skin_skl_cross,
        )
        # Names used when loading pretrained checkpoints; SkinAE weights are
        # remapped onto the 'skin_encoder' submodule.
        self.pretrain_class_name = ["AniGenElasticSLatEncoder", skin_ae_name]
        self.pretrain_ckpt_filter_prefix = {skin_ae_name: "skin_encoder"}
        self.resolution = resolution

        self.latent_channels = latent_channels
        self.latent_channels_skl = latent_channels_skl
        self.latent_channels_vertskin = latent_channels_vertskin

        skin_encoder_config['use_fp16'] = use_fp16
        self.skin_encoder = SkinEncoder(**skin_encoder_config)
        self.encode_upsampled_skin_feat = encode_upsampled_skin_feat
        # When encoding upsampled skin features, each voxel contributes 8
        # sub-voxel samples, so the input layer sees 8x the channels.
        self.in_layer_skin = FeedForwardNet(channels=self.skin_encoder.skin_feat_channels * (8 if encode_upsampled_skin_feat else 1), channels_out=model_channels_skin)

        # 4th input dim carries the hyper-continuity factor when enabled.
        self.pos_embedder_fourier = FreqPositionalEmbedder(in_dim=4 if self.jp_hyper_continuous else 3, max_freq_log2=self.skl_pos_embed_freq, num_freqs=self.skl_pos_embed_freq, include_input=True)
        self.root_embedding = nn.Parameter(torch.zeros(1, self.pos_embedder_fourier.out_dim))

        # Channel Balance: 1/4 of the skeleton channels carry positional
        # embeddings, the remaining 3/4 carry joint skin embeddings.
        self.in_layer_jp_skl = FeedForwardNet(channels=2 * self.pos_embedder_fourier.out_dim, channels_out=model_channels_skl//4)
        self.in_layer_skin_skl = FeedForwardNet(channels=self.skin_encoder.skin_feat_channels, channels_out=model_channels_skl-(model_channels_skl//4))

        # Geometry head is always a VAE head (mean + logvar -> 2x channels);
        # the skeleton / vertex-skin heads drop logvar in denoising mode.
        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
        if self.latent_denoising:
            self.out_layer_skl = sp.SparseLinear(model_channels_skl, latent_channels_skl)
            self.out_layer_vertskin = sp.SparseLinear(model_channels_skin, latent_channels_vertskin)
        else:
            self.out_layer_skl = sp.SparseLinear(model_channels_skl, 2 * latent_channels_skl)
            self.out_layer_vertskin = sp.SparseLinear(model_channels_skin, 2 * latent_channels_vertskin)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
        else:
            self.convert_to_fp32()

        # Optionally freeze the pretrained branch, skipping any LoRA params.
        if 'all' in modules_to_freeze:
            modules_to_freeze = list(set([k.split('.')[0] for k in self.state_dict().keys()]))
            print(f"\033[93mFreezing all modules: {modules_to_freeze}\033[0m")
        if self.use_pretrain_branch and self.freeze_pretrain_branch:
            for module in modules_to_freeze:
                if hasattr(self, module):
                    mod = getattr(self, module)
                    # NOTE: the ModuleList branch is redundant (ModuleList is a
                    # Module); kept for parity with the original control flow.
                    if isinstance(mod, nn.ModuleList):
                        for m in mod:
                            for name, param in m.named_parameters():
                                if 'lora' not in name:
                                    param.requires_grad = False
                    elif isinstance(mod, nn.Module):
                        for name, param in mod.named_parameters():
                            if 'lora' not in name:
                                param.requires_grad = False
                    elif isinstance(mod, torch.Tensor):
                        if mod.requires_grad:
                            mod.requires_grad = False

    def initialize_weights(self) -> None:
        """Base init plus zero-init of the geometry VAE output layer."""
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def skeleton_embedding(self, x, x_skl, joints_list, parents_list, skin_list, gt_meshes, bvh_list=None):
        """
        Build per-voxel skin features and per-skeleton-voxel embeddings.

        For each sample: embed each skeleton voxel with the Fourier embedding
        of its nearest joint and that joint's parent (root parents use a
        learned embedding), concatenated with the joint's skin embedding;
        embed each geometry voxel with skin features sampled from the ground
        truth mesh (barycentric via BVH when available, else nearest vertex).

        Returns:
            (x_new, x_skl_new, joint_skin_embeds, vert_skin_embeds, joints_pos_list)
            where x_new / x_skl_new are sparse tensors over the input coords.
        """
        res = self.resolution
        feats_new = []
        feats_skl_new = []
        coords_new = []
        coords_skl_new = []

        joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
        joints_pos_list = []

        for i in range(len(joints_list)):
            parent_idx = parents_list[i].clone()

            # NOTE(review): this writes the batch index into the input sparse
            # tensors' coords in place (coords are aliased, not copied).
            coords_new.append(x[i].coords)
            coords_skl_new.append(x_skl[i].coords)
            coords_new[-1][:, 0] = i
            coords_skl_new[-1][:, 0] = i

            # Voxel centers mapped to [-0.5, 0.5].
            v_pos = (x[i].coords[:, 1:4] + 0.5) / res - 0.5
            v_pos_skl = (x_skl[i].coords[:, 1:4] + 0.5) / res - 0.5
            # Two nearest joints per skeleton voxel (2nd used for the
            # hyper-continuity factor below).
            dist_nn_12, joints_nn_idx, _ = knn_points(v_pos_skl[None], joints_list[i][None], K=2, norm=2, return_nn=False)
            joints_nn_idx = joints_nn_idx[0, :, 0]

            # Skeleton positional embedding
            joints_pos = joints_list[i][joints_nn_idx] - (v_pos_skl if self.jp_residual_fields else 0)
            parents_pos = joints_list[i][parent_idx[joints_nn_idx]] - (v_pos_skl if self.jp_residual_fields else 0)
            if self.jp_hyper_continuous:
                # 1 - d1/d2: approaches 0 where the two nearest joints are
                # equidistant, giving a continuous assignment factor.
                factor = (1 - (dist_nn_12[0, :, 0:1] / (dist_nn_12[0, :, 1:2] + 1e-8)).clamp(max=1.0))
                joints_pos = torch.cat([joints_pos, factor], dim=-1)
                parents_pos = torch.cat([parents_pos, factor], dim=-1)
            joints_pos_embed = self.pos_embedder_fourier(joints_pos)
            parents_pos_embed = self.pos_embedder_fourier(parents_pos)
            # Root joints (parent == -1) use the learned root embedding.
            parents_pos_embed = torch.where(parent_idx[joints_nn_idx][:, None] == -1, self.root_embedding.expand_as(parents_pos_embed), parents_pos_embed)
            jp_pos_embed_nn = torch.cat([joints_pos_embed, parents_pos_embed], dim=-1)
            jp_pos_embed_nn = self.in_layer_jp_skl(jp_pos_embed_nn)

            # Skeleton skin embedding
            j_skin_embed_nn = joint_skin_embeds[i][joints_nn_idx]
            j_skin_embed_nn = self.in_layer_skin_skl(j_skin_embed_nn)

            # Concatenate
            jp_skl_embed = torch.cat([jp_pos_embed_nn, j_skin_embed_nn], dim=-1)
            feats_skl_new.append(jp_skl_embed)

            if self.encode_upsampled_skin_feat:
                # Create 8 sub-voxel points
                offsets = torch.tensor([
                    [-1, -1, -1], [-1, -1, 1], [-1, 1, -1], [-1, 1, 1],
                    [1, -1, -1], [1, -1, 1], [1, 1, -1], [1, 1, 1]
                ], device=v_pos.device, dtype=v_pos.dtype) * (0.25 / res)
                query_pos = v_pos.unsqueeze(1) + offsets.unsqueeze(0)  # (N, 8, 3)
                query_pos_flat = query_pos.view(-1, 3)
            else:
                query_pos_flat = v_pos

            if bvh_list is not None:
                # Barycentric interpolation of vertex skin embeddings on the
                # closest mesh triangle.
                bvh = bvh_list[i].to(v_pos.device)
                _, face_id, uvw = bvh.unsigned_distance(query_pos_flat, return_uvw=True)
                uvw = uvw.clamp(min=0.0)
                uvw_sum = uvw.sum(dim=-1, keepdim=True).clamp_min(1e-6)
                uvw = uvw / uvw_sum
                face_id = gt_meshes[i]['faces'][face_id]
                voxel_skin_embeds = (vert_skin_embeds[i][face_id] * uvw[..., None]).sum(1)
            else:
                # Fallback: nearest mesh vertex.
                gt_mesh_verts = gt_meshes[i]['vertices']
                _, mesh_nn_idx, _ = knn_points(query_pos_flat[None], gt_mesh_verts[None], K=1, norm=2, return_nn=False)
                mesh_nn_idx = mesh_nn_idx[0, :, 0]
                voxel_skin_embeds = vert_skin_embeds[i][mesh_nn_idx]

            # Flatten the (optional) 8 sub-voxel samples into the channel dim.
            voxel_skin_embeds = voxel_skin_embeds.view(v_pos.shape[0], -1)
            voxel_skin_embeds = self.in_layer_skin(voxel_skin_embeds)
            feats_new.append(voxel_skin_embeds)
            joints_pos_list.append(joints_pos)

        x_new = sp.SparseTensor(coords=torch.cat(coords_new, dim=0), feats=torch.cat(feats_new, dim=0))
        x_skl_new = sp.SparseTensor(coords=torch.cat(coords_skl_new, dim=0), feats=torch.cat(feats_skl_new, dim=0))

        return x_new, x_skl_new, joint_skin_embeds, vert_skin_embeds, joints_pos_list

    def encode_sample(self, x: sp.SparseTensor, out_layer: sp.SparseLinear, sample_posterior: bool = True, latent_denoising: bool = False):
        """
        Project transformer features to a latent and (optionally) sample.

        In denoising mode the head outputs a deterministic latent: logvar is
        zero, the latent may be L2-normalized, and ``mean`` is detached so the
        VAE KL/regression losses do not backprop through it.

        Returns:
            (z, mean, logvar) with ``z`` a SparseTensor sharing ``x``'s coords.
        """
        x = x.type(torch.float32)
        x = x.replace(F.layer_norm(x.feats, x.feats.shape[-1:]))
        x = out_layer(x)
        if latent_denoising:
            if self.normalize_latent:
                x = x.replace(nn.functional.normalize(x.feats, dim=-1) * self.normalize_scale)
            mean, logvar = x.feats, torch.zeros_like(x.feats)
        else:
            mean, logvar = x.feats.chunk(2, dim=-1)
        if sample_posterior and not latent_denoising:
            # Reparameterization trick.
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = x.replace(z)
        if latent_denoising:
            mean = mean.detach()
        return z, mean, logvar

    def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, sample_posterior=True, return_raw=False, return_skin_encoded=False, **kwargs):
        """
        Encode geometry + skeleton into latents.

        Expected kwargs: ``gt_joints``, ``gt_parents``, ``gt_skin``,
        ``gt_mesh`` (lists, one entry per sample) and optional ``bvh_list``.

        Returns (depending on flags) ``z`` (geometry latent concatenated with
        the per-voxel skin latent), ``z_skl``, the raw mean/logvar pairs, the
        skin embeddings, nearest-joint positions, and optionally the skin
        inputs fed to the transformer.
        """
        x_skin, x_skl, joint_skin_embeds, vert_skin_embeds, joints_pos = self.skeleton_embedding(x, x_skl, kwargs.get('gt_joints'), kwargs.get('gt_parents'), kwargs.get('gt_skin'), kwargs.get('gt_mesh'), kwargs.get('bvh_list', None))
        h, h_skl, h_skin = super().forward(x, x_skl, x_skin)

        z, mean, logvar = self.encode_sample(h, self.out_layer, sample_posterior, latent_denoising=False)
        z_skl, mean_skl, logvar_skl = self.encode_sample(h_skl, self.out_layer_skl, sample_posterior, latent_denoising=self.latent_denoising)
        z_skin, mean_skin, logvar_skin = self.encode_sample(h_skin, self.out_layer_vertskin, sample_posterior, latent_denoising=self.latent_denoising)

        # Concatenate the per-voxel skin latent onto the geometry latent.
        z = z.replace(torch.cat([z.feats, z_skin.feats], dim=-1))
        mean, logvar = torch.cat([mean, mean_skin], dim=-1), torch.cat([logvar, logvar_skin], dim=-1)

        if not return_skin_encoded:
            # Ordinary return without skin encoded features
            if return_raw:
                return z, mean, logvar, z_skl, mean_skl, logvar_skl, joint_skin_embeds, vert_skin_embeds, joints_pos
            else:
                return z, z_skl, joint_skin_embeds, vert_skin_embeds, joints_pos
        else:
            # Return skin encoded features as well for checking
            if return_raw:
                return z, mean, logvar, z_skl, mean_skl, logvar_skl, joint_skin_embeds, vert_skin_embeds, joints_pos, x_skin, x_skl
            else:
                return z, z_skl, joint_skin_embeds, vert_skin_embeds, joints_pos, x_skin, x_skl

    def encode_skin(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor] = None):
        """Run only the skin encoder, returning (joint_skin_embeds, vert_skin_embeds)."""
        joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
        return joint_skin_embeds, vert_skin_embeds
312
+
313
+
314
class AniGenElasticSLatEncoder(SparseTransformerElasticMixin, AniGenSLatEncoder):
    """
    SLat VAE encoder with elastic memory management.
    Used for training with low VRAM.

    NOTE(review): the mixin precedes the encoder in the MRO, so it presumably
    wraps forward with memory-elastic checkpointing — confirm against
    SparseTransformerElasticMixin.
    """
anigen/models/structured_latent_vae/base.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ...modules.utils import convert_module_to_f16, convert_module_to_f32
5
+ from ...modules import sparse as sp
6
+ from ...modules.transformer import AbsolutePositionEmbedder
7
+ from ...modules.sparse.transformer import SparseTransformerBlock
8
+
9
+
10
def block_attn_config(self):
    """
    Yield one attention configuration per transformer block, as the tuple
    (attn_mode, window_size, shift_sequence, shift_window, serialize_mode).
    Shifted variants alternate between even and odd block indices.
    """
    mode = self.attn_mode
    for idx in range(self.num_blocks):
        odd = idx % 2
        if mode == "shift_window":
            shift = 16 * odd
            yield "serialized", self.window_size, 0, (shift, shift, shift), sp.SerializeMode.Z_ORDER
        elif mode == "shift_sequence":
            yield "serialized", self.window_size, (self.window_size // 2) * odd, (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif mode == "shift_order":
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[idx % 4]
        elif mode == "full":
            yield "full", None, None, None, None
        elif mode == "swin":
            yield "windowed", self.window_size, None, (self.window_size // 2) * odd, None
25
+
26
+
27
class SparseTransformerBase(nn.Module):
    """
    Sparse Transformer without output layers.
    Serve as the base class for encoder and decoder.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        """
        Args:
            in_channels: width of the incoming sparse features.
            model_channels: transformer width.
            num_blocks: number of SparseTransformerBlock layers.
            num_heads: explicit head count; if None, derived as
                model_channels // num_head_channels.
            attn_mode: per-block attention scheme expanded by block_attn_config.
            window_size: window for windowed/serialized attention modes.
            pe_mode: "ape" adds absolute position embeddings to the input;
                "rope" enables rotary embeddings inside each block.
            use_fp16: run the transformer torso in float16.
        """
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.window_size = window_size
        # Derive head count from head width when not given explicitly.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.attn_mode = attn_mode
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.qk_rms_norm = qk_rms_norm
        self.dtype = torch.float16 if use_fp16 else torch.float32

        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        self.input_layer = sp.SparseLinear(in_channels, model_channels)
        # One block per tuple yielded by block_attn_config(self); note the
        # loop variables shadow the constructor arguments of the same name.
        self.blocks = nn.ModuleList([
            SparseTransformerBlock(
                model_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode=attn_mode,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
            )
            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
        ])

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-init every Linear weight; zero every Linear bias."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Project input features, add APE (if enabled), run all blocks."""
        h = self.input_layer(x)
        if self.pe_mode == "ape":
            # coords[:, 0] is the batch index; positions are coords[:, 1:].
            h = h + self.pos_embedder(x.coords[:, 1:])
        h = h.type(self.dtype)
        for block in self.blocks:
            h = block(h)
        return h
anigen/models/structured_latent_vae/skin_models.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import *
4
+ from ..sparse_elastic_mixin import SparseTransformerElasticMixin
5
+ from ...modules.transformer import TransformerBlock, FeedForwardNet
6
+ from .anigen_base import FreqPositionalEmbedder, TransformerCrossBlock
7
+ from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
8
+
9
+
10
class Embedder(nn.Module):
    """
    Small embedding trunk: an input MLP, a stack of either full-attention
    transformer blocks (jp_embed_attn=True) or residual MLP blocks, and a
    linear output projection.
    """
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int=None, depth: int = 4, mlp_ratio: float = 4.0, jp_embed_attn: bool = True):
        super().__init__()
        hidden_dim = out_dim if hidden_dim is None else hidden_dim
        self.jp_embed_attn = jp_embed_attn
        self.in_layer = FeedForwardNet(channels=in_dim, out_channels=hidden_dim, mlp_ratio=mlp_ratio)
        if self.jp_embed_attn:
            self.blocks = nn.ModuleList([TransformerBlock(hidden_dim, num_heads=8, attn_mode='full') for _ in range(depth)])
            # NOTE(review): attention blocks are hard-cast to fp16 regardless
            # of any use_fp16 flag elsewhere — confirm this is intended.
            for block in self.blocks:
                block.to(torch.float16)
        else:
            self.blocks = nn.ModuleList([FeedForwardNet(channels=hidden_dim, out_channels=hidden_dim, mlp_ratio=mlp_ratio) for _ in range(depth)])
        self.out_layer = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed ``x`` (tokens, in_dim) -> (tokens, out_dim).

        In attention mode, a batch dim is added and features are cast to fp16
        for the blocks, then cast back for the output layer. In MLP mode each
        block output gets a "+ x" skip from the in_layer output (operator
        precedence: the "+ x" applies only to the MLP branch, and always adds
        the in_layer output rather than the previous block's output).
        """
        x = self.in_layer(x)
        h = x
        for block in self.blocks:
            h = block(h[None].type(torch.float16))[0] if self.jp_embed_attn else block(h) + x
        h = self.out_layer(h.type(x.dtype))
        return h
31
+
32
+
33
class SkinEncoder(nn.Module):
    """
    Encodes a skeleton (joints + parent indices) into per-joint skin
    embeddings, and optionally projects per-vertex skin weights into
    per-vertex skin embeddings via a weighted sum of joint embeddings.
    """
    def __init__(self, skin_feat_channels: int = 8, skl_pos_embed_freq: int = 10, jp_embedder_config: Optional[Dict[str, Any]] = {}, jp_embed_dim: int = 128, relative_pe=True, vert_feat_is_linear=True, normalize_feat=True, **kwargs):
        # NOTE(review): jp_embedder_config uses a shared mutable default dict;
        # it is only read here, but callers should still pass a fresh dict.
        super().__init__()
        self.skin_feat_channels = skin_feat_channels
        self.skl_pos_embed_freq = skl_pos_embed_freq
        self.jp_embedder_config = jp_embedder_config

        self.pos_embedder_fourier = FreqPositionalEmbedder(in_dim=3, max_freq_log2=self.skl_pos_embed_freq, num_freqs=self.skl_pos_embed_freq, include_input=True)
        self.pos_embedder_linear = nn.Linear(self.pos_embedder_fourier.out_dim, jp_embed_dim)
        # Learned embedding appended as the "parent" of root joints.
        self.root_embedding = nn.Parameter(torch.zeros(1, jp_embed_dim))
        self.joint_embedder = Embedder(in_dim=2 * jp_embed_dim, out_dim=jp_embed_dim, **self.jp_embedder_config)
        self.out_layer_vert = FeedForwardNet(channels=jp_embed_dim, out_channels=skin_feat_channels)
        self.out_layer_joint = FeedForwardNet(channels=jp_embed_dim, out_channels=skin_feat_channels)
        self.relative_pe = relative_pe
        self.vert_feat_is_linear = vert_feat_is_linear
        self.normalize_feat = normalize_feat

    def forward(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]=None):
        """
        Args:
            joints_list: per-sample (J, 3) joint positions.
            parents_list: per-sample (J,) parent indices, -1 for roots.
            skin_list: optional per-sample (V, J) skin-weight matrices.

        Returns:
            (joint_skin_embeds, vert_skin_embeds); the latter is None when
            skin_list is None.
        """
        vert_skin_embeds = [] if skin_list is not None else None
        joint_skin_embeds = []
        for i in range(len(joints_list)):
            parent_idx = parents_list[i].clone()
            joints = joints_list[i]
            if self.relative_pe:
                # Express each joint relative to its parent. The concat makes
                # index -1 (root parent) hit the appended copy of joint 0, so
                # roots become relative to the first joint — presumably the
                # first joint is the root; TODO confirm.
                joints = joints - torch.cat([joints, joints[:1]])[parent_idx]
            joints_pos_embed = self.pos_embedder_linear(self.pos_embedder_fourier(joints))
            # Append root_embedding so parent index -1 selects it below.
            joints_pos_embed = torch.cat([joints_pos_embed, self.root_embedding], dim=0)
            parents_pos_embed = joints_pos_embed[parent_idx]
            # Pair each joint embedding with its parent's embedding.
            jp_pos_embed = torch.cat([joints_pos_embed[:-1], parents_pos_embed], dim=-1)
            joints_embed = self.joint_embedder(jp_pos_embed)
            if self.normalize_feat:
                joints_embed = torch.nn.functional.normalize(joints_embed, dim=-1)
            if skin_list is not None:
                vert_skin = skin_list[i]
                if self.vert_feat_is_linear:
                    # Project joints first, then blend: keeps the vertex
                    # embedding a linear function of the skin weights.
                    joints_embed_for_vert = self.out_layer_vert(joints_embed)
                    vert_skin_embed = vert_skin @ joints_embed_for_vert
                else:
                    vert_skin_embed = vert_skin @ joints_embed
                    vert_skin_embed = self.out_layer_vert(vert_skin_embed)
                vert_skin_embeds.append(vert_skin_embed)
            joints_embed = self.out_layer_joint(joints_embed)
            joint_skin_embeds.append(joints_embed)
        return joint_skin_embeds, vert_skin_embeds
77
+
78
+
79
def clamp_with_grad(x, min, max):
    """
    Clamp ``x`` to [min, max] in the forward pass while passing gradients
    through unmodified (straight-through estimator).
    """
    clamped = x.clamp(min, max)
    return x + (clamped - x).detach()
81
+
82
+
83
class TreeTransformerSkinDecoder(nn.Module):
    # The principles of the tree transformer skinning model:
    # 1. joint features are related to the tree structure, since the decoding process is skeleton-agnostic.
    # 2. decode the skinning weights directly, hoping the transformer can handle the skinning assignment.
    # It's a pure learning-based method.
    def __init__(self,
                 skin_feat_channels: int,
                 model_channels: int = 512,
                 num_heads=4,
                 num_blocks=4,
                 vert_cross_blocks_num: int = 1,
                 use_fp16: bool = False):
        """
        Args:
            skin_feat_channels: width of the incoming vertex/joint skin features.
            model_channels: internal transformer width.
            vert_cross_blocks_num: how many leading vertex blocks cross-attend
                to joints; the remaining vertex blocks are plain MLPs.
            use_fp16: run the transformer torso in float16.
        """
        super().__init__()
        self.skin_feat_channels = skin_feat_channels
        self.model_channels = model_channels
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        # Learned feature standing in as the "parent" of root joints.
        self.root_features = nn.Parameter(torch.zeros([1, skin_feat_channels]), requires_grad=True)
        self.input_layer_vertex = nn.Linear(skin_feat_channels, model_channels)
        # Joint tokens are (joint, parent) feature pairs -> 2x channels.
        self.input_layer_skin = nn.Linear(skin_feat_channels*2, model_channels)
        assert vert_cross_blocks_num <= num_blocks, f"vert_cross_blocks_num should be less than or equal to num_blocks, got {vert_cross_blocks_num} and {num_blocks}."
        self.vert_cross_blocks_num = vert_cross_blocks_num
        # First vert_cross_blocks_num vertex blocks cross-attend to joint
        # tokens; the rest are per-vertex feed-forward nets.
        self.blocks_vertex = nn.ModuleList([
            TransformerCrossBlock(
                channels=model_channels,
                ctx_channels=model_channels,
                num_heads=num_heads,
                mlp_ratio=4.0,
                attn_mode="full",
                no_self=True)
            for _ in range(self.vert_cross_blocks_num)
        ] + [
            FeedForwardNet(
                channels=model_channels,
                mlp_ratio=4.0,
                out_channels=model_channels,
            )
            for _ in range(num_blocks - self.vert_cross_blocks_num)
        ])
        # Joint branch: full self-attention over joint tokens.
        self.blocks_skin = nn.ModuleList([
            TransformerBlock(
                channels=model_channels,
                num_heads=num_heads,
                mlp_ratio=4.0,
                attn_mode="full")
            for _ in range(num_blocks)
        ])
        # Vertex head emits model_channels features + 1 temperature logit.
        self.out_layer_vertex = nn.Sequential(
            nn.Linear(model_channels, model_channels*4),
            nn.GELU(approximate="tanh"),
            nn.Linear(model_channels*4, model_channels+1),
        )
        self.out_layer_skin = nn.Sequential(
            nn.Linear(model_channels, model_channels*4),
            nn.GELU(approximate="tanh"),
            nn.Linear(model_channels*4, model_channels),
        )
        self.temp_activation = nn.ELU(alpha=1.0)
        self.dtype = torch.float16 if use_fp16 else torch.float32

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks_vertex.apply(convert_module_to_f16)
        self.blocks_skin.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks_vertex.apply(convert_module_to_f32)
        self.blocks_skin.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-init every Linear weight; zero every Linear bias."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, vertex_features, joint_features, parents) -> torch.Tensor:
        """
        Decode skin weights from vertex and joint skin features.

        Args:
            vertex_features: (B, V, skin_feat_channels).
            joint_features: (B, J, skin_feat_channels).
            parents: (B, J) parent indices, -1 for roots. Only ``parents[0]``
                is used, so all batch items must share one skeleton (callers
                in this repo always pass B == 1).

        Returns:
            (B, V, J) skin weights, softmax-normalized over joints.
        """
        j_num = joint_features.shape[1]
        h_v = vertex_features
        h_v = self.input_layer_vertex(h_v)
        h_j = joint_features
        # Append root_features so a parent index of j_num selects it.
        h_j = torch.cat([h_j, self.root_features[None]], dim=1)
        # Remap -1 (root) to the appended slot j_num.
        parents = torch.where(parents < 0, torch.ones_like(parents)*j_num, parents)
        # Each joint token = [own features, parent features].
        h_j = torch.cat([h_j[:, :-1], h_j[:, parents[0]]], dim=-1)
        h_j = self.input_layer_skin(h_j)
        h_v = h_v.type(self.dtype)
        h_j = h_j.type(self.dtype)
        blocks_num = len(self.blocks_vertex)
        for idx, block_v, block_j in zip(range(blocks_num), self.blocks_vertex, self.blocks_skin):
            f_v, f_j = h_v, h_j
            # Cross blocks attend vertices->joints; later blocks are MLP-only.
            h_v = block_v(f_v, f_j) if idx < self.vert_cross_blocks_num else block_v(f_v)
            h_j = block_j(f_j)
        # Cast back to the callers' dtypes before the fp32 output heads.
        h_v = h_v.type(vertex_features.dtype)
        h_j = h_j.type(joint_features.dtype)
        h_v = self.out_layer_vertex(h_v)
        h_j = self.out_layer_skin(h_j)
        # Split off the per-vertex inverse-temperature logit.
        h_v, inv_temp = h_v[..., :-1], h_v[..., -1].unsqueeze(-1)
        # ELU(x) + alpha + 1 maps to (1, inf): temperature stays positive.
        inv_temp = self.temp_activation(inv_temp) + self.temp_activation.alpha + 1.0
        # Vertex-joint affinity scores, sharpened per vertex, then softmaxed.
        skin_weights = torch.einsum("nac,nbc->nab", h_v, h_j)
        skin_weights = torch.softmax(skin_weights * inv_temp, dim=-1)
        return skin_weights
198
+
199
+
200
# Registry mapping the config key 'model_type' to skin-decoder classes.
SKIN_MODEL_DICT = {'tree': TreeTransformerSkinDecoder}
201
+
202
+
203
class SkinAutoEncoder(nn.Module):
    """
    Autoencoder for skinning weights: a SkinEncoder maps skeletons + skin
    weights to joint/vertex embeddings, and a decoder (looked up in
    SKIN_MODEL_DICT via the config key 'model_type') reconstructs per-vertex
    skin weights.

    Fix vs. original: ``decoder_config`` was mutated in place (a 'use_fp16'
    key was written into it and 'model_type' was popped out of it), silently
    corrupting the caller's dict; the config is now copied first.
    """
    def __init__(self, encoder_config: Dict[str, Any], decoder_config: Dict[str, Any], use_fp16: bool = False):
        super().__init__()
        self.skin_encoder = SkinEncoder(**encoder_config)
        # Copy before mutating so the caller's config dict is left untouched.
        decoder_config = dict(decoder_config)
        decoder_config['use_fp16'] = use_fp16
        self.skin_decoder = SKIN_MODEL_DICT[decoder_config.pop('model_type')](**decoder_config)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()
        else:
            self.convert_to_fp32()

    def convert_to_fp16(self) -> None:
        """Convert the decoder torso to float16 (the encoder stays fp32)."""
        self.skin_decoder.convert_to_fp16()

    def convert_to_fp32(self) -> None:
        """Convert the decoder torso to float32."""
        self.skin_decoder.convert_to_fp32()

    def initialize_weights(self) -> None:
        """Xavier-init every Linear weight; zero every Linear bias."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def encode(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]):
        """Encode skeletons + skin weights into (joint_skin_embeds, vert_skin_embeds)."""
        joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
        return joint_skin_embeds, vert_skin_embeds

    def decode(self, vertex_features, joint_features, parents) -> torch.Tensor:
        """Decode (B, V, J) skin weights from vertex/joint embeddings."""
        skin_weights = self.skin_decoder(vertex_features, joint_features, parents)
        return skin_weights

    def forward(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]):
        """
        Full round trip: encode then decode each sample with batch dim 1.

        Returns:
            (skin_pred_list, joint_skin_embeds, vert_skin_embeds).
        """
        joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
        skin_pred_list = []
        for i in range(len(joints_list)):
            skin_pred = self.skin_decoder(vert_skin_embeds[i][None], joint_skin_embeds[i][None], parents_list[i][None])
            skin_pred_list.append(skin_pred[0])
        return skin_pred_list, joint_skin_embeds, vert_skin_embeds
246
+
247
+
248
class AniGenElasticSLatEncoderGamma(SparseTransformerElasticMixin, SkinAutoEncoder):
    """
    SLat VAE encoder with elastic memory management.
    Used for training with low VRAM.

    NOTE(review): despite the name/docstring, the base class is
    SkinAutoEncoder (not the SLat encoder) — confirm whether the name or the
    base is the intended one.
    """
anigen/modules/attention/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
# Default attention backend; overridable via the ATTN_BACKEND env var.
BACKEND = 'flash_attn'
# Extra debug checks in the attention kernels; enabled via ATTN_DEBUG=1.
DEBUG = False

def __from_env():
    """Read ATTN_BACKEND / ATTN_DEBUG environment variables into the module globals."""
    import os

    global BACKEND
    global DEBUG

    backend_override = os.environ.get('ATTN_BACKEND')
    debug_override = os.environ.get('ATTN_DEBUG')

    # Unknown backend names are ignored and the default is kept.
    if backend_override in ('xformers', 'flash_attn', 'sdpa', 'naive'):
        BACKEND = backend_override
    if debug_override is not None:
        DEBUG = debug_override == '1'

    print(f"[ATTENTION] Using backend: {BACKEND}")


__from_env()
24
+
25
+
26
+ def set_backend(backend: Literal['xformers', 'flash_attn']):
27
+ global BACKEND
28
+ BACKEND = backend
29
+
30
def set_debug(debug: bool):
    """Toggle the module-wide attention debug flag."""
    global DEBUG
    DEBUG = debug
33
+
34
+
35
+ from .full_attn import *
36
+ from .modules import *
anigen/modules/attention/full_attn.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from . import DEBUG, BACKEND
5
+
6
+ if BACKEND == 'xformers':
7
+ import xformers.ops as xops
8
+ elif BACKEND == 'flash_attn':
9
+ import flash_attn
10
+ elif BACKEND == 'sdpa':
11
+ from torch.nn.functional import scaled_dot_product_attention as sdpa
12
+ elif BACKEND == 'naive':
13
+ pass
14
+ else:
15
+ raise ValueError(f"Unknown attention backend: {BACKEND}")
16
+
17
+
18
+ __all__ = [
19
+ 'scaled_dot_product_attention',
20
+ ]
21
+
22
+
23
def _naive_sdpa(q, k, v):
    """
    Reference (naive) scaled dot product attention in plain PyTorch.

    q, k, v are [N, L, H, C] tensors; the result is also [N, L, H, C].
    """
    # Move heads ahead of the sequence axis: [N, H, L, C].
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    # Scaled attention logits and row-wise softmax.
    logits = torch.einsum('nhlc,nhmc->nhlm', q, k) / math.sqrt(q.size(-1))
    weights = torch.softmax(logits, dim=-1)
    out = weights @ v
    # Restore [N, L, H, C] layout.
    return out.transpose(1, 2)
36
+
37
+
38
@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
        kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

def scaled_dot_product_attention(*args, **kwargs):
    # Dispatch by total argument count: 1 -> packed qkv, 2 -> q + packed kv,
    # 3 -> separate q/k/v. Positional and keyword forms are both accepted.
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    # Any argument not supplied positionally must be present as a keyword.
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        # Packed form: qkv is [N, L, 3, H, C].
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
        device = qkv.device

    elif num_all_args == 2:
        # q is [N, L, H, C]; kv packs keys and values as [N, L, 2, H, C].
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
        device = q.device

    elif num_all_args == 3:
        # Fully separate q/k/v; value channel dim (Co) may differ from Ci.
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
        device = q.device

    # Backend dispatch; BACKEND is fixed at import time (see __init__.py) and
    # only the matching library was imported at module load.
    if BACKEND == 'xformers':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = xops.memory_efficient_attention(q, k, v)
    elif BACKEND == 'flash_attn':
        # flash_attn consumes the packed layouts directly.
        if num_all_args == 1:
            out = flash_attn.flash_attn_qkvpacked_func(qkv)
        elif num_all_args == 2:
            out = flash_attn.flash_attn_kvpacked_func(q, kv)
        elif num_all_args == 3:
            out = flash_attn.flash_attn_func(q, k, v)
    elif BACKEND == 'sdpa':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        # torch sdpa expects [N, H, L, C]; permute in and out.
        q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
        k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
        v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
        out = sdpa(q, k, v)         # [N, H, L, C]
        out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
    elif BACKEND == 'naive':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = _naive_sdpa(q, k, v)
    else:
        raise ValueError(f"Unknown attention module: {BACKEND}")

    return out
anigen/modules/attention/modules.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .full_attn import scaled_dot_product_attention
6
+
7
+
8
class MultiHeadRMSNorm(nn.Module):
    """RMS normalization over the channel dim with a learnable per-head gain."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for stability, then cast back to the input dtype.
        normed = F.normalize(x.float(), dim=-1)
        return (normed * self.gamma * self.scale).to(x.dtype)
16
+
17
+
18
class RotaryPositionEmbedder(nn.Module):
    """
    Rotary position embedding (RoPE) over multi-axis integer coordinates.

    The hidden size is split across `in_channels` coordinate axes; each axis
    gets `hidden_size // in_channels // 2` complex frequencies.
    """
    def __init__(self, hidden_size: int, in_channels: int = 3):
        super().__init__()
        assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.freq_dim = hidden_size // in_channels // 2
        # NOTE(review): plain tensor, not a registered buffer — it is moved to
        # the input's device lazily in _get_phases and is absent from state_dict.
        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** self.freqs)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        # Unit-magnitude complex phases e^{i * index * freq}: [len(indices), freq_dim].
        self.freqs = self.freqs.to(indices.device)
        phases = torch.outer(indices, self.freqs)
        phases = torch.polar(torch.ones_like(phases), phases)
        return phases

    def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        # Interpret consecutive channel pairs as complex numbers, rotate by the
        # phases, then flatten back to real channels in the input dtype.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = x_complex * phases
        x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
        return x_embed

    def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            q (sp.SparseTensor): [..., N, D] tensor of queries
            k (sp.SparseTensor): [..., N, D] tensor of keys
            indices (torch.Tensor): [..., N, C] tensor of spatial positions

        Returns:
            The rotated (q, k) pair, same shapes and dtypes as the inputs.
        """
        if indices is None:
            # Fall back to 1D sequence positions when no coordinates are given.
            indices = torch.arange(q.shape[-2], device=q.device)
            if len(q.shape) > 2:
                indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))

        phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
        if phases.shape[1] < self.hidden_size // 2:
            # Pad the remaining channels with identity rotations (phase 0) when
            # hidden_size does not divide evenly across the coordinate axes.
            phases = torch.cat([phases, torch.polar(
                torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
                torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
            )], dim=-1)
        q_embed = self._rotary_embedding(q, phases)
        k_embed = self._rotary_embedding(k, phases)
        return q_embed, k_embed
61
+
62
+
63
class LoRALinear(nn.Linear):
    """nn.Linear augmented with an additive low-rank (LoRA) residual path."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True, rank: int = 4, lr_rate: float = 1.0):
        super().__init__(in_features, out_features, bias)
        self.rank = rank
        # A starts at zero so the LoRA path is a no-op at initialization;
        # B gets small random values.
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.randn(rank, out_features) * 1e-2)
        self.lr_rate = lr_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = super().forward(x)
        lora = (x @ self.lora_A) @ self.lora_B
        return base + lora * self.lr_rate
72
+
73
+
74
class MultiHeadAttention(nn.Module):
    """
    Dense multi-head self/cross attention with optional RoPE, QK RMS-norm,
    and LoRA-adapted projections. Only `attn_mode="full"` is implemented.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int]=None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        x_is_query: bool = False,
        use_lora: bool = False,
        lora_rank: int = 4,
        lora_lr_rate: float = 1.0,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"

        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        # Projections; LoRALinear swaps in a low-rank-adapted linear when requested.
        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) if not use_lora else LoRALinear(channels, channels * 3, bias=qkv_bias, rank=lora_rank, lr_rate=lora_lr_rate)
        else:
            # x_is_query=True bypasses the query projection entirely (identity).
            self.to_q = (lambda x: x) if x_is_query else (nn.Linear(channels, channels, bias=qkv_bias) if not use_lora else LoRALinear(channels, channels, bias=qkv_bias, rank=lora_rank, lr_rate=lora_lr_rate))
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) if not use_lora else LoRALinear(self.ctx_channels, channels * 2, bias=qkv_bias, rank=lora_rank, lr_rate=lora_lr_rate)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels) if not use_lora else LoRALinear(channels, channels, rank=lora_rank, lr_rate=lora_lr_rate)

        if use_rope:
            self.rope = RotaryPositionEmbedder(channels)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: [B, L, C] input (queries; also keys/values for self-attention).
            context: [B, Lkv, ctx_channels] keys/values source (cross-attention only).
            indices: optional spatial positions for RoPE (self-attention only).

        Returns:
            [B, L, C] attended features.
        """
        B, L, C = x.shape
        if self._type == "self":
            qkv = self.to_qkv(x)
            qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
            if self.use_rope:
                # Rotate q and k by position, leave v untouched.
                q, k, v = qkv.unbind(dim=2)
                q, k = self.rope(q, k, indices)
                qkv = torch.stack([q, k, v], dim=2)
            if self.attn_mode == "full":
                if self.qk_rms_norm:
                    q, k, v = qkv.unbind(dim=2)
                    q = self.q_rms_norm(q)
                    k = self.k_rms_norm(k)
                    h = scaled_dot_product_attention(q, k, v)
                else:
                    # Packed path: attention backend consumes qkv directly.
                    h = scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                raise NotImplementedError("Windowed attention is not yet implemented")
        else:
            Lkv = context.shape[1]
            q = self.to_q(x)
            kv = self.to_kv(context)
            q = q.reshape(B, L, self.num_heads, -1)
            kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
            else:
                h = scaled_dot_product_attention(q, kv)
        # Merge heads back into the channel dim and project out.
        h = h.reshape(B, L, -1)
        h = self.to_out(h)
        return h
anigen/modules/norm.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class LayerNorm32(nn.LayerNorm):
    """LayerNorm computed in float32, with the result cast back to the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = super().forward(x.float())
        return out.to(x.dtype)
8
+
9
+
10
class GroupNorm32(nn.GroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in fp32 for numerical stability, then restore the input dtype.
        original_dtype = x.dtype
        out = super().forward(x.float())
        return out.to(original_dtype)
16
+
17
+
18
class ChannelLayerNorm32(LayerNorm32):
    """LayerNorm32 applied over a channels-first tensor [N, C, ...]."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels to the last axis, normalize, then restore channels-first.
        x = x.movedim(1, -1).contiguous()
        x = super().forward(x)
        return x.movedim(-1, 1).contiguous()
25
+
anigen/modules/sparse/__init__.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+
3
# Sparse-conv backend; overridable via SPARSE_BACKEND ('spconv' or 'torchsparse').
BACKEND = 'spconv'
# Extra debug checking for sparse ops; overridable via SPARSE_DEBUG=1.
DEBUG = False
# Attention backend for sparse attention; overridable via SPARSE_ATTN_BACKEND
# (falls back to ATTN_BACKEND).
ATTN = 'flash_attn'
6
+
7
def __from_env():
    """Initialize BACKEND, DEBUG, and ATTN from environment variables."""
    import os

    global BACKEND
    global DEBUG
    global ATTN

    backend_env = os.environ.get('SPARSE_BACKEND')
    debug_env = os.environ.get('SPARSE_DEBUG')
    # SPARSE_ATTN_BACKEND takes precedence; fall back to the generic ATTN_BACKEND.
    attn_env = os.environ.get('SPARSE_ATTN_BACKEND')
    if attn_env is None:
        attn_env = os.environ.get('ATTN_BACKEND')

    # Only accept known values; anything else keeps the defaults.
    if backend_env in ['spconv', 'torchsparse']:
        BACKEND = backend_env
    if debug_env is not None:
        DEBUG = debug_env == '1'
    if attn_env in ['xformers', 'flash_attn']:
        ATTN = attn_env

    print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")


__from_env()
31
+
32
+
33
def set_backend(backend: Literal['spconv', 'torchsparse']):
    """Select the sparse-convolution backend at runtime."""
    global BACKEND
    BACKEND = backend
36
+
37
def set_debug(debug: bool):
    """Enable or disable sparse-op debug checks (same effect as SPARSE_DEBUG=1)."""
    global DEBUG
    DEBUG = debug
40
+
41
def set_attn(attn: Literal['xformers', 'flash_attn']):
    """Select the attention backend used by sparse attention at runtime."""
    global ATTN
    ATTN = attn
44
+
45
+
46
+ import importlib
47
+
48
# Lazy-import table: maps each public name exported by this package to the
# submodule that defines it. Resolution happens on first attribute access.
__attributes = {
    'SparseTensor': 'basic',
    'sparse_batch_broadcast': 'basic',
    'sparse_batch_op': 'basic',
    'sparse_cat': 'basic',
    'sparse_unbind': 'basic',
    'SparseGroupNorm': 'norm',
    'SparseLayerNorm': 'norm',
    'SparseGroupNorm32': 'norm',
    'SparseLayerNorm32': 'norm',
    'SparseReLU': 'nonlinearity',
    'SparseSiLU': 'nonlinearity',
    'SparseGELU': 'nonlinearity',
    'SparseActivation': 'nonlinearity',
    'SparseLinear': 'linear',
    'sparse_scaled_dot_product_attention': 'attention',
    'SerializeMode': 'attention',
    'sparse_serialized_scaled_dot_product_self_attention': 'attention',
    'sparse_windowed_scaled_dot_product_self_attention': 'attention',
    'SparseMultiHeadAttention': 'attention',
    'SparseConv3d': 'conv',
    'SparseInverseConv3d': 'conv',
    'SparseDownsample': 'spatial',
    'SparseUpsample': 'spatial',
    'SparseSubdivide' : 'spatial'
}

# Submodules exposed lazily as attributes of this package.
__submodules = ['transformer']

__all__ = list(__attributes.keys()) + __submodules
78
+
79
def __getattr__(name):
    """Lazily import and cache attributes declared in __attributes / __submodules."""
    # Already resolved on a previous access: return the cached object.
    if name in globals():
        return globals()[name]
    if name in __attributes:
        # Import the defining submodule and cache the requested symbol.
        module = importlib.import_module(f".{__attributes[name]}", __name__)
        globals()[name] = getattr(module, name)
    elif name in __submodules:
        globals()[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
91
+
92
+
93
# For Pylance
# Dead at runtime when imported as a package (__name__ is the package name);
# these eager imports exist only so static analyzers see the lazy attributes.
if __name__ == '__main__':
    from .basic import *
    from .norm import *
    from .nonlinearity import *
    from .linear import *
    from .attention import *
    from .conv import *
    from .spatial import *
    import transformer  # NOTE(review): likely meant `from . import transformer`; harmless since unreachable
anigen/modules/sparse/attention/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .full_attn import *
2
+ from .serialized_attn import *
3
+ from .windowed_attn import *
4
+ from .modules import *
5
+ from .windowed_attn_cross import *
anigen/modules/sparse/attention/full_attn.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ from .. import SparseTensor
4
+ from .. import DEBUG, ATTN
5
+
6
+ if ATTN == 'xformers':
7
+ import xformers.ops as xops
8
+ elif ATTN == 'flash_attn':
9
+ import flash_attn
10
+ else:
11
+ raise ValueError(f"Unknown attention module: {ATTN}")
12
+
13
+
14
+ __all__ = [
15
+ 'sparse_scaled_dot_product_attention',
16
+ ]
17
+
18
+
19
@overload
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, L, H, C] dense tensor containing Qs.
        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
    """
    ...

def sparse_scaled_dot_product_attention(*args, **kwargs):
    # Dispatch by total argument count (packed qkv / q + packed kv / separate q,k,v).
    # Sparse inputs are flattened to varlen token streams with per-sample lengths
    # derived from their layouts; dense inputs get a constant length L per sample.
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
        assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
        device = qkv.device

        # Remember the sparse container (s) to rewrap the result; per-sample
        # sequence lengths come from the layout slices.
        s = qkv
        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
        kv_seqlen = q_seqlen
        qkv = qkv.feats     # [T, 3, H, C]

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
               isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
               f"Invalid types, got {type(q)} and {type(kv)}"
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, C]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
            s = None
            N, L, H, C = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)   # [T_Q, H, C]

        if isinstance(kv, SparseTensor):
            assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
            kv = kv.feats     # [T_KV, 2, H, C]
        else:
            assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
            N, L, _, H, C = kv.shape
            kv_seqlen = [L] * N
            kv = kv.reshape(N * L, 2, H, C)   # [T_KV, 2, H, C]

    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        # k and v must be the same kind (both sparse or both dense).
        assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
               isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
               f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, Ci]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
            s = None
            N, L, H, CI = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, CI)  # [T_Q, H, Ci]

        if isinstance(k, SparseTensor):
            assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
            assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
            kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
            k = k.feats     # [T_KV, H, Ci]
            v = v.feats     # [T_KV, H, Co]
        else:
            assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
            assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
            N, L, H, CI, CO = *k.shape, v.shape[-1]
            kv_seqlen = [L] * N
            k = k.reshape(N * L, H, CI)     # [T_KV, H, Ci]
            v = v.reshape(N * L, H, CO)     # [T_KV, H, Co]

    if DEBUG:
        if s is not None:
            for i in range(s.shape[0]):
                assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
        # NOTE(review): these asserts compare torch.Size (a tuple) against a list,
        # which is never equal in Python, and the flattened q is [T_Q, H, C] rather
        # than [1, T_Q] — they look broken; confirm before relying on DEBUG=1.
        if num_all_args in [2, 3]:
            assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
        if num_all_args == 3:
            assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
            assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"

    if ATTN == 'xformers':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        # xformers wants a leading batch dim of 1; variable lengths are encoded
        # in the block-diagonal attention mask.
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
        out = xops.memory_efficient_attention(q, k, v, mask)[0]
    elif ATTN == 'flash_attn':
        # flash_attn varlen kernels take cumulative sequence-length offsets.
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
        if num_all_args == 1:
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
        elif num_all_args == 2:
            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
        elif num_all_args == 3:
            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
    else:
        raise ValueError(f"Unknown attention module: {ATTN}")

    # Rewrap into the sparse container when the query was sparse; otherwise
    # restore the dense [N, L, H, C] layout.
    if s is not None:
        return s.replace(out)
    else:
        return out.reshape(N, L, H, -1)
anigen/modules/sparse/attention/modules.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .. import SparseTensor
6
+ from .full_attn import sparse_scaled_dot_product_attention
7
+ from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
8
+ from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
9
+ from .windowed_attn_cross import sparse_windowed_scaled_dot_product_cross_attention
10
+ from ...attention import RotaryPositionEmbedder
11
+
12
+
13
class SparseMultiHeadRMSNorm(nn.Module):
    """RMS normalization over the channel dim for dense tensors or SparseTensors."""
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
        x_type = x.dtype
        # Normalize in fp32; for a SparseTensor the op is applied to its .feats
        # payload and rewrapped via .replace.
        # NOTE(review): assumes SparseTensor supports .float(), broadcasting
        # __mul__, and .to(dtype) like a dense tensor — defined in ..basic; confirm.
        x = x.float()
        if isinstance(x, SparseTensor):
            x = x.replace(F.normalize(x.feats, dim=-1))
        else:
            x = F.normalize(x, dim=-1)
        return (x * self.gamma * self.scale).to(x_type)
27
+
28
+
29
class SparseMultiHeadAttention(nn.Module):
    """
    Multi-head self/cross attention over SparseTensors (or dense tensors),
    supporting full, serialized (self only), and windowed attention modes,
    optional RoPE (self only), and optional QK RMS-norm.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "serialized", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        cross_attn_cache_suffix: str = '',
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or (attn_mode == "full" or attn_mode == "windowed"), "Cross-attention only supports full and windowed attention"
        assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
        self.channels = channels
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_sequence = shift_sequence
        self.shift_window = shift_window
        self.serialize_mode = serialize_mode
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
            # Suffix used to distinguish cached windowed cross-attention buffers.
            self.cross_attn_cache_suffix = cross_attn_cache_suffix

        if self.qk_rms_norm:
            self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
            self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = RotaryPositionEmbedder(channels)

    @staticmethod
    def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
        # Apply a linear layer to dense features or to a SparseTensor's payload.
        if isinstance(x, SparseTensor):
            return x.replace(module(x.feats))
        else:
            return module(x)

    @staticmethod
    def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
        # Reshape the channel dims; for dense tensors keep [N, L] leading dims.
        if isinstance(x, SparseTensor):
            return x.reshape(*shape)
        else:
            return x.reshape(*x.shape[:2], *shape)

    def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
        # Split a fused projection into [num_fused, H, C//H] per token.
        if isinstance(x, SparseTensor):
            x_feats = x.feats.unsqueeze(0)
        else:
            x_feats = x
        x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
        return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats

    def _rope(self, qkv: SparseTensor) -> SparseTensor:
        # Rotate q and k by the sparse voxel coordinates (batch index dropped).
        q, k, v = qkv.feats.unbind(dim=1)   # [T, H, C]
        q, k = self.rope(q, k, qkv.coords[:, 1:])
        qkv = qkv.replace(torch.stack([q, k, v], dim=1))
        return qkv

    def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
        """
        Args:
            x: queries (and keys/values for self-attention).
            context: keys/values source for cross-attention.

        Returns:
            Attended features of the same container type as `x`.
        """
        if self._type == "self":
            qkv = self._linear(self.to_qkv, x)
            qkv = self._fused_pre(qkv, num_fused=3)
            if self.use_rope:
                qkv = self._rope(qkv)
            if self.qk_rms_norm:
                q, k, v = qkv.unbind(dim=1)
                q = self.q_rms_norm(q)
                k = self.k_rms_norm(k)
                qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
            if self.attn_mode == "full":
                h = sparse_scaled_dot_product_attention(qkv)
            elif self.attn_mode == "serialized":
                h = sparse_serialized_scaled_dot_product_self_attention(
                    qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
                )
            elif self.attn_mode == "windowed":
                h = sparse_windowed_scaled_dot_product_self_attention(
                    qkv, self.window_size, shift_window=self.shift_window
                )
        else:
            q = self._linear(self.to_q, x)
            q = self._reshape_chs(q, (self.num_heads, -1))
            kv = self._linear(self.to_kv, context)
            kv = self._fused_pre(kv, num_fused=2)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=1)
                k = self.k_rms_norm(k)
                kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
            if self.attn_mode == "full":
                h = sparse_scaled_dot_product_attention(q, kv)
            elif self.attn_mode == "windowed":
                q = self._fused_pre(q, num_fused=1)
                h = sparse_windowed_scaled_dot_product_cross_attention(
                    q, kv, self.window_size, shift_window=self.shift_window,
                    cache_suffix=self.cross_attn_cache_suffix
                )
            elif self.attn_mode == "serialized":
                raise NotImplementedError("Serialized attention is not supported for cross-attention")
        # Merge heads back into the channel dim and project out.
        h = self._reshape_chs(h, (-1,))
        h = self._linear(self.to_out, h)
        return h
anigen/modules/sparse/attention/serialized_attn.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from enum import Enum
3
+ import torch
4
+ import math
5
+ from .. import SparseTensor
6
+ from .. import DEBUG, ATTN
7
+
8
+ if ATTN == 'xformers':
9
+ import xformers.ops as xops
10
+ elif ATTN == 'flash_attn':
11
+ import flash_attn
12
+ else:
13
+ raise ValueError(f"Unknown attention module: {ATTN}")
14
+
15
+
16
+ __all__ = [
17
+ 'sparse_serialized_scaled_dot_product_self_attention',
18
+ ]
19
+
20
+
21
class SerializeMode(Enum):
    """Space-filling-curve orderings used to serialize sparse voxel coordinates."""
    Z_ORDER = 0
    Z_ORDER_TRANSPOSED = 1
    HILBERT = 2
    HILBERT_TRANSPOSED = 3
26
+
27
+
28
# All supported serialization orderings, in declaration order of the enum.
SerializeModes = list(SerializeMode)
34
+
35
+
36
def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor): Forward gather indices into `tensor` (points may repeat
            when windows are padded up to `window_size`).
        (torch.Tensor): Backward indices mapping each original point to its
            position in the gathered sequence.
        (List[int]): Length of each attention window.
        (List[int]): Batch index of each attention window.
    """
    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    offsets = [0]

    # Lazy import: vox2seq is only needed when serialized attention is used.
    if 'vox2seq' not in globals():
        import vox2seq

    # Serialize the input
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
    if serialize_mode == SerializeMode.Z_ORDER:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
    elif serialize_mode == SerializeMode.HILBERT:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")

    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        num_windows = (num_points + window_size - 1) // window_size
        # Fractional number of "valid" (non-padded) points per window.
        valid_window_size = num_points / num_windows
        to_ordered = torch.argsort(code[s.start:s.stop])
        if num_windows == 1:
            # Single window: no padding needed, the whole batch is one sequence.
            fwd_indices.append(to_ordered)
            bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
            fwd_indices[-1] += s.start
            bwd_indices[-1] += offsets[-1]
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            offsets.append(offsets[-1] + seq_lens[-1])
        else:
            # Partition the input
            offset = 0
            # Window centers/boundaries along the serialized order;
            # `shift_sequence` rotates the partition between layers.
            mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
            split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
            bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
            for i in range(num_windows):
                mid = mids[i]
                valid_start = split[i]
                valid_end = split[i + 1]
                # Pad each window to exactly `window_size` by wrapping around the
                # serialized sequence (modulo num_points); only the
                # [valid_start, valid_end) portion feeds the backward mapping.
                padded_start = math.floor(mid - 0.5 * window_size)
                padded_end = padded_start + window_size
                fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
                offset += valid_start - padded_start
                bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
                offset += padded_end - valid_start
                fwd_indices[-1] += s.start
            seq_lens.extend([window_size] * num_windows)
            seq_batch_indices.extend([bi] * num_windows)
            bwd_indices.append(bwd_index + offsets[-1])
            offsets.append(offsets[-1] + num_windows * window_size)

    fwd_indices = torch.cat(fwd_indices)
    bwd_indices = torch.cat(bwd_indices)

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
118
+
119
+
120
def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    Points are ordered along a space-filling curve and chunked into windows of
    `window_size`; attention is computed independently within each window.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: Attended features with the QKV axis removed ([N, *, H, C]).
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # The partition depends only on coordinates, so it is cached on the tensor
    # and reused by subsequent layers with identical settings.
    serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]   # total gathered (padded) points
    T = qkv.feats.shape[0]     # original point count
    H = qkv.feats.shape[2]     # attention heads
    C = qkv.feats.shape[3]     # channels per head

    qkv_feats = qkv.feats[fwd_indices]  # [M, 3, H, C]

    if DEBUG:
        # Verify every window only contains points from a single batch item.
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        # Uniform windows: use the faster fixed-length batched kernels.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)  # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)  # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)  # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)  # [M, H, C]
    else:
        # Ragged windows: fall back to variable-length attention.
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)  # [M, H, C]
            q = q.unsqueeze(0)  # [1, M, H, C]
            k = k.unsqueeze(0)  # [1, M, H, C]
            v = v.unsqueeze(0)  # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0]  # [M, H, C]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens))  # [M, H, C]

    # Scatter window outputs back to the original point order.
    out = out[bwd_indices]  # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
anigen/modules/sparse/attention/windowed_attn.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from .. import SparseTensor
5
+ from .. import DEBUG, ATTN
6
+
7
+ if ATTN == 'xformers':
8
+ import xformers.ops as xops
9
+ elif ATTN == 'flash_attn':
10
+ import flash_attn
11
+ else:
12
+ raise ValueError(f"Unknown attention module: {ATTN}")
13
+
14
+
15
+ __all__ = [
16
+ 'sparse_windowed_scaled_dot_product_self_attention',
17
+ ]
18
+
19
+
20
def calc_window_partition(
    tensor: SparseTensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Points are grouped into non-overlapping spatial windows of `window_size`
    voxels (optionally shifted by `shift_window`) and sorted so that every
    window occupies a contiguous run.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        shift_window (Tuple[int, ...]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor): Forwards indices.
        (torch.Tensor): Backwards indices.
        (List[int]): Sequence lengths.
        (List[int]): Sequence batch indices.
    """
    DIM = tensor.coords.shape[1] - 1
    # Normalize scalar arguments to one value per spatial dimension.
    shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
    window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
    shifted_coords = tensor.coords.clone().detach()
    shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)

    # Flatten (batch, window_x, window_y, ...) into a single linear window id.
    MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
    NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
    OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]

    shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
    shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
    # Sorting by window id makes every window contiguous; bwd inverts the sort.
    fwd_indices = torch.argsort(shifted_indices)
    bwd_indices = torch.empty_like(fwd_indices)
    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
    seq_lens = torch.bincount(shifted_indices)
    # OFFSET[0] is the number of windows per batch item, so integer division
    # recovers the batch index of each window.
    seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
    # Drop windows that contain no points.
    mask = seq_lens != 0
    seq_lens = seq_lens[mask].tolist()
    seq_batch_indices = seq_batch_indices[mask].tolist()

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
61
+
62
+
63
def sparse_windowed_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply windowed scaled dot product self attention to a sparse tensor.

    Points are partitioned into non-overlapping spatial windows; attention is
    computed independently within each window.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: Attended features with the QKV axis removed ([N, *, H, C]).
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # The partition depends only on coordinates, so it is cached on the tensor
    # and reused by subsequent layers with identical settings.
    serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]   # total points after gathering into windows
    T = qkv.feats.shape[0]     # original point count
    H = qkv.feats.shape[2]     # attention heads
    C = qkv.feats.shape[3]     # channels per head

    qkv_feats = qkv.feats[fwd_indices]  # [M, 3, H, C]

    if DEBUG:
        # Verify windows are batch-pure and spatially bounded by window_size.
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            seq_coords = qkv_coords[start:start+seq_lens[i]]
            assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
                f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        # Uniform windows: use the faster fixed-length batched kernels.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)  # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)  # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)  # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)  # [M, H, C]
    else:
        # Ragged windows: fall back to variable-length attention.
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)  # [M, H, C]
            q = q.unsqueeze(0)  # [1, M, H, C]
            k = k.unsqueeze(0)  # [1, M, H, C]
            v = v.unsqueeze(0)  # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0]  # [M, H, C]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens))  # [M, H, C]

    # Scatter window outputs back to the original point order.
    out = out[bwd_indices]  # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
anigen/modules/sparse/attention/windowed_attn_cross.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import math
4
+ from .. import SparseTensor
5
+ from .. import DEBUG, ATTN
6
+
7
+ if ATTN == 'xformers':
8
+ import xformers.ops as xops
9
+ elif ATTN == 'flash_attn':
10
+ import flash_attn
11
+ else:
12
+ raise ValueError(f"Unknown attention module: {ATTN}")
13
+
14
+
15
+ __all__ = [
16
+ 'sparse_windowed_scaled_dot_product_cross_attention',
17
+ ]
18
+
19
+
20
def calc_window_partition_cross(
    tensor: SparseTensor,
    context: SparseTensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0
) -> Tuple[torch.Tensor, torch.Tensor, List[int], torch.Tensor, torch.Tensor, List[int]]:
    """
    Calculate window partitions for a query tensor and a context tensor so
    the two share a common window grid for cross-attention.

    Args:
        tensor (SparseTensor): The query-side tensor.
        context (SparseTensor): The key/value-side tensor.
        window_size (int): The window size to use.
        shift_window (Tuple[int, ...]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor): Forward indices for `tensor`.
        (torch.Tensor): Backward indices for `tensor`.
        (List[int]): Per-window sequence lengths for `tensor`.
        (torch.Tensor): Forward indices for `context`.
        (torch.Tensor): Backward indices for `context`.
        (List[int]): Per-window sequence lengths for `context`.
    """

    def calc_window_partition_(tensor, window_size, shift_window):
        # Same scheme as calc_window_partition in windowed_attn.py: flatten
        # (batch, window coords) into one linear window id and sort by it.
        DIM = tensor.coords.shape[1] - 1
        shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
        window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
        shifted_coords = tensor.coords.clone().detach()
        shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)

        MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
        NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
        OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]

        shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
        shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
        fwd_indices = torch.argsort(shifted_indices)
        bwd_indices = torch.empty_like(fwd_indices)
        bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
        seq_lens = torch.bincount(shifted_indices)
        seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
        return fwd_indices, bwd_indices, seq_lens, seq_batch_indices

    # NOTE(review): the per-window batch indices are discarded, and the two
    # partitions are assumed to line up window-for-window. That only holds if
    # both tensors produce the same window-id numbering (same NUM_WINDOWS per
    # axis); the zero-padding below only aligns the tails. TODO confirm.
    fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition_(tensor, window_size, shift_window)
    fwd_indices_context, bwd_indices_context, seq_lens_context, seq_batch_indices_context = calc_window_partition_(context, window_size, shift_window)
    # Pad the shorter one to the shape of the other with 0 tail
    max_len = max(seq_lens.shape[0], seq_lens_context.shape[0])
    if seq_lens.shape[0] < max_len:
        pad_size = max_len - seq_lens.shape[0]
        seq_lens = torch.cat([seq_lens, torch.zeros(pad_size, dtype=seq_lens.dtype, device=seq_lens.device)])
    if seq_lens_context.shape[0] < max_len:
        pad_size = max_len - seq_lens_context.shape[0]
        seq_lens_context = torch.cat([seq_lens_context, torch.zeros(pad_size, dtype=seq_lens_context.dtype, device=seq_lens_context.device)])
    # Keep windows non-empty on either side; a window empty on one side
    # contributes a zero-length segment there.
    mask = (seq_lens != 0) | (seq_lens_context != 0)
    seq_lens = seq_lens[mask].tolist()
    seq_lens_context = seq_lens_context[mask].tolist()

    return fwd_indices, bwd_indices, seq_lens, fwd_indices_context, bwd_indices_context, seq_lens_context
76
+
77
+
78
def sparse_windowed_scaled_dot_product_cross_attention(
    q: SparseTensor,
    kv: SparseTensor,
    window_size: int,
    shift_window: Tuple[int, int, int] = (0, 0, 0),
    cache_suffix: str = '',
) -> SparseTensor:
    """
    Apply windowed scaled dot product cross attention to a sparse tensor.

    Args:
        q (SparseTensor): [N, *, 1, H, C] sparse tensor containing the queries.
        kv (SparseTensor): [N, *, 2, H, C] sparse tensor containing keys and values.
        window_size (int): The window size to use.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
        cache_suffix (str): Suffix appended to the spatial-cache key so that
            different q/kv pairings do not collide in the cache.

    Returns:
        SparseTensor: Attended query features ([N, *, H, C]).
    """
    assert len(q.shape) == 4 and q.shape[1] == 1, f"Invalid shape for q, got {q.shape}, expected [N, *, 1, H, C]"
    assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"

    # Window partitions depend only on coordinates; cache them on each tensor.
    serialization_spatial_cache_name_q = f'window_partition_{window_size}_{shift_window}_cross_q' + cache_suffix
    serialization_spatial_cache_q = q.get_spatial_cache(serialization_spatial_cache_name_q)
    serialization_spatial_cache_name_kv = f'window_partition_{window_size}_{shift_window}_cross_kv' + cache_suffix
    serialization_spatial_cache_kv = kv.get_spatial_cache(serialization_spatial_cache_name_kv)
    if serialization_spatial_cache_q is None or serialization_spatial_cache_kv is None:
        q_fwd_indices, q_bwd_indices, q_seq_lens, kv_fwd_indices, kv_bwd_indices, kv_seq_lens = calc_window_partition_cross(q, kv, window_size, shift_window)
        q.register_spatial_cache(serialization_spatial_cache_name_q, (q_fwd_indices, q_bwd_indices, q_seq_lens))
        kv.register_spatial_cache(serialization_spatial_cache_name_kv, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens))
    else:
        kv_fwd_indices, kv_bwd_indices, kv_seq_lens = serialization_spatial_cache_kv
        q_fwd_indices, q_bwd_indices, q_seq_lens = serialization_spatial_cache_q

    M_q, T_q, H_q, C_q = q_fwd_indices.shape[0], q.feats.shape[0], q.feats.shape[2], q.feats.shape[3]
    M_kv, T_kv, H_kv, C_kv = kv_fwd_indices.shape[0], kv.feats.shape[0], kv.feats.shape[2], kv.feats.shape[3]
    assert (H_q == H_kv and C_q == C_kv), \
        f"Mismatch in shapes: q ({M_q}, {T_q}, {H_q}, {C_q}), kv ({M_kv}, {T_kv}, {H_kv}, {C_kv})"

    q_feats = q.feats[q_fwd_indices]    # [M, 1, H, C]
    kv_feats = kv.feats[kv_fwd_indices]  # [M, 2, H, C]

    if ATTN == 'xformers':
        # BUGFIX: the original `q, k, v = q_feats[:, 0], kv_feats.unbind(dim=1)`
        # unpacked a 2-tuple into three names (ValueError) and shadowed the
        # SparseTensor `q`, breaking `q.replace(out)` below. Use fresh locals.
        q_ = q_feats[:, 0].unsqueeze(0)  # [1, M, H, C]
        k_, v_ = kv_feats.unbind(dim=1)  # [M, H, C] each
        k_ = k_.unsqueeze(0)  # [1, M, H, C]
        v_ = v_.unsqueeze(0)  # [1, M, H, C]
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens)
        out = xops.memory_efficient_attention(q_, k_, v_, mask)[0]  # [M, H, C]
    elif ATTN == 'flash_attn':
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seq_lens), dim=0)], dim=0).to(q.device).int()
        cu_seqlens_k = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seq_lens), dim=0)], dim=0).to(kv.device).int()
        out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats[:, 0], kv_feats, cu_seqlens_q, cu_seqlens_k, max(q_seq_lens), max(kv_seq_lens))  # [M, H, C]

    # Scatter window outputs back to the original query-point order.
    out = out[q_bwd_indices]  # [T, H, C]
    return q.replace(out)
131
+
anigen/modules/sparse/basic.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from . import BACKEND, DEBUG
5
+ SparseTensorData = None # Lazy import
6
+
7
+ import importlib
8
+ if BACKEND == 'torchsparse':
9
+ SparseTensorData = importlib.import_module('torchsparse').SparseTensor
10
+ elif BACKEND == 'spconv':
11
+ SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
12
+
13
+
14
+ __all__ = [
15
+ 'SparseTensor',
16
+ 'sparse_batch_broadcast',
17
+ 'sparse_batch_op',
18
+ 'sparse_cat',
19
+ 'sparse_unbind',
20
+ ]
21
+
22
+
23
+ class SparseTensor:
24
+ """
25
+ Sparse tensor with support for both torchsparse and spconv backends.
26
+
27
+ Parameters:
28
+ - feats (torch.Tensor): Features of the sparse tensor.
29
+ - coords (torch.Tensor): Coordinates of the sparse tensor.
30
+ - shape (torch.Size): Shape of the sparse tensor.
31
+ - layout (List[slice]): Layout of the sparse tensor for each batch
32
+ - data (SparseTensorData): Sparse tensor data used for convolusion
33
+
34
+ NOTE:
35
+ - Data corresponding to a same batch should be contiguous.
36
+ - Coords should be in [0, 1023]
37
+ """
38
+ @overload
39
+ def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
40
+
41
+ @overload
42
+ def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
43
+
44
    def __init__(self, *args, **kwargs):
        """
        Construct either from raw (feats, coords) tensors or by wrapping an
        existing backend tensor (see the @overload stubs above).

        Recognized kwargs `scale` and `spatial_cache` seed the wrapper's
        metadata; remaining kwargs are forwarded to the backend constructor.
        """
        # Lazy import of sparse tensor backend
        global SparseTensorData
        if SparseTensorData is None:
            import importlib
            if BACKEND == 'torchsparse':
                SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif BACKEND == 'spconv':
                SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        # Overload dispatch: method 0 builds from (feats, coords) tensors,
        # method 1 wraps a pre-built backend tensor.
        method_id = 0
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            # Merge positional and keyword forms of the four known arguments,
            # removing them from kwargs so leftovers go to the backend.
            feats, coords, shape, layout = args + (None,) * (4 - len(args))
            if 'feats' in kwargs:
                feats = kwargs['feats']
                del kwargs['feats']
            if 'coords' in kwargs:
                coords = kwargs['coords']
                del kwargs['coords']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            if shape is None:
                shape = self.__cal_shape(feats, coords)
            if layout is None:
                layout = self.__cal_layout(coords, shape[0])
            if BACKEND == 'torchsparse':
                self.data = SparseTensorData(feats, coords, **kwargs)
            elif BACKEND == 'spconv':
                # spconv needs flattened per-point features and an explicit
                # spatial extent; the original (possibly >2D) feats are stashed
                # back into the container afterwards.
                spatial_shape = list(coords.max(0)[0] + 1)[1:]
                self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
                self.data._features = feats
        elif method_id == 1:
            data, shape, layout = args + (None,) * (3 - len(args))
            if 'data' in kwargs:
                data = kwargs['data']
                del kwargs['data']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            self.data = data
            if shape is None:
                shape = self.__cal_shape(self.feats, self.coords)
            if layout is None:
                layout = self.__cal_layout(self.coords, shape[0])

        self._shape = shape
        self._layout = layout
        self._scale = kwargs.get('scale', (1, 1, 1))
        self._spatial_cache = kwargs.get('spatial_cache', {})

        if DEBUG:
            # Sanity checks: feats/coords agree, cached shape/layout match a
            # recomputation, and each batch's points are contiguous.
            try:
                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
                for i in range(self.shape[0]):
                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
            except Exception as e:
                print('Debugging information:')
                print(f"- Shape: {self.shape}")
                print(f"- Layout: {self.layout}")
                print(f"- Scale: {self._scale}")
                print(f"- Coords: {self.coords}")
                raise e
122
+
123
+ def __cal_shape(self, feats, coords):
124
+ shape = []
125
+ shape.append(coords[:, 0].max().item() + 1)
126
+ shape.extend([*feats.shape[1:]])
127
+ return torch.Size(shape)
128
+
129
+ def __cal_layout(self, coords, batch_size):
130
+ seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
131
+ offset = torch.cumsum(seq_len, dim=0)
132
+ layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
133
+ return layout
134
+
135
    @property
    def shape(self) -> torch.Size:
        # Logical shape (batch_size, *feature_dims), cached at construction.
        return self._shape
138
+
139
    def dim(self) -> int:
        # Number of logical dimensions, mirroring torch.Tensor.dim().
        return len(self.shape)
141
+
142
    @property
    def layout(self) -> List[slice]:
        # Per-batch contiguous slices into the point axis of feats/coords.
        return self._layout
145
+
146
    @property
    def feats(self) -> torch.Tensor:
        # Per-point features, read from the backend-specific container.
        if BACKEND == 'torchsparse':
            return self.data.F
        elif BACKEND == 'spconv':
            return self.data.features
152
+
153
    @feats.setter
    def feats(self, value: torch.Tensor):
        # Write features through to the backend container.
        if BACKEND == 'torchsparse':
            self.data.F = value
        elif BACKEND == 'spconv':
            self.data.features = value
159
+
160
    @property
    def coords(self) -> torch.Tensor:
        # Integer coordinates; column 0 is the batch index, the rest spatial.
        if BACKEND == 'torchsparse':
            return self.data.C
        elif BACKEND == 'spconv':
            return self.data.indices
166
+
167
    @coords.setter
    def coords(self, value: torch.Tensor):
        # Write coordinates through to the backend container.
        if BACKEND == 'torchsparse':
            self.data.C = value
        elif BACKEND == 'spconv':
            self.data.indices = value
173
+
174
    @property
    def dtype(self) -> torch.dtype:
        # dtype of the features (coordinates keep their own integer dtype).
        return self.feats.dtype
177
+
178
    @property
    def device(self) -> torch.device:
        # Device of the features (coordinates are moved alongside by `to`).
        return self.feats.device
181
+
182
+ @overload
183
+ def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
184
+
185
+ @overload
186
+ def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
187
+
188
+ def to(self, *args, **kwargs) -> 'SparseTensor':
189
+ device = None
190
+ dtype = None
191
+ if len(args) == 2:
192
+ device, dtype = args
193
+ elif len(args) == 1:
194
+ if isinstance(args[0], torch.dtype):
195
+ dtype = args[0]
196
+ else:
197
+ device = args[0]
198
+ if 'dtype' in kwargs:
199
+ assert dtype is None, "to() received multiple values for argument 'dtype'"
200
+ dtype = kwargs['dtype']
201
+ if 'device' in kwargs:
202
+ assert device is None, "to() received multiple values for argument 'device'"
203
+ device = kwargs['device']
204
+
205
+ new_feats = self.feats.to(device=device, dtype=dtype)
206
+ new_coords = self.coords.to(device=device)
207
+ return self.replace(new_feats, new_coords)
208
+
209
+ def type(self, dtype):
210
+ new_feats = self.feats.type(dtype)
211
+ return self.replace(new_feats)
212
+
213
+ def cpu(self) -> 'SparseTensor':
214
+ new_feats = self.feats.cpu()
215
+ new_coords = self.coords.cpu()
216
+ return self.replace(new_feats, new_coords)
217
+
218
+ def cuda(self) -> 'SparseTensor':
219
+ new_feats = self.feats.cuda()
220
+ new_coords = self.coords.cuda()
221
+ return self.replace(new_feats, new_coords)
222
+
223
+ def half(self) -> 'SparseTensor':
224
+ new_feats = self.feats.half()
225
+ return self.replace(new_feats)
226
+
227
+ def float(self) -> 'SparseTensor':
228
+ new_feats = self.feats.float()
229
+ return self.replace(new_feats)
230
+
231
+ def detach(self) -> 'SparseTensor':
232
+ new_coords = self.coords.detach()
233
+ new_feats = self.feats.detach()
234
+ return self.replace(new_feats, new_coords)
235
+
236
    def dense(self) -> torch.Tensor:
        # Materialize as a dense tensor via the backend. Both backends expose
        # `dense()`, so the branches are currently identical; they are kept
        # separate for symmetry with the other backend dispatches.
        if BACKEND == 'torchsparse':
            return self.data.dense()
        elif BACKEND == 'spconv':
            return self.data.dense()
241
+
242
+ def reshape(self, *shape) -> 'SparseTensor':
243
+ new_feats = self.feats.reshape(self.feats.shape[0], *shape)
244
+ return self.replace(new_feats)
245
+
246
    def unbind(self, dim: int) -> List['SparseTensor']:
        # Delegates to the module-level sparse_unbind helper (defined below).
        return sparse_unbind(self, dim)
248
+
249
    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
        """
        Return a new SparseTensor sharing this tensor's metadata (layout,
        scale, spatial cache) but carrying new features and, optionally, new
        coordinates.
        """
        # Batch size is preserved; remaining dims follow the new feats.
        new_shape = [self.shape[0]]
        new_shape.extend(feats.shape[1:])
        if BACKEND == 'torchsparse':
            new_data = SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            # Share the backend's internal caches with the source tensor.
            new_data._caches = self.data._caches
        elif BACKEND == 'spconv':
            new_data = SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict
            )
            # spconv stores flattened features; stash the original-rank feats.
            new_data._features = feats
            # Carry over benchmarking/allocator state so conv behavior matches.
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
        return new_tensor
281
+
282
    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
        """Create a fully-dense SparseTensor covering an integer bounding box.

        Args:
            aabb: (x0, y0, z0, x1, y1, z1) inclusive integer bounds.
            dim: (N, C) — batch size and channel count.
            value: constant fill value for every feature entry.
            dtype: feature dtype. Defaults to float32.
            device: target device for coords and feats.
        """
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        # All voxel coordinates inside the box, one row per voxel.
        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
        # Prepend the batch index column and tile the box once per batch item;
        # the repeat/view interleaving keeps each batch's voxels contiguous.
        coords = torch.cat([
            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
            coords.repeat(N, 1),
        ], dim=1).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)
295
+
296
+ def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
297
+ new_cache = {}
298
+ for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
299
+ if k in self._spatial_cache:
300
+ new_cache[k] = self._spatial_cache[k]
301
+ if k in other._spatial_cache:
302
+ if k not in new_cache:
303
+ new_cache[k] = other._spatial_cache[k]
304
+ else:
305
+ new_cache[k].update(other._spatial_cache[k])
306
+ return new_cache
307
+
308
    def __neg__(self) -> 'SparseTensor':
        # Unary minus: negate the features, keep coordinates as-is.
        return self.replace(-self.feats)
310
+
311
+ def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
312
+ if isinstance(other, torch.Tensor):
313
+ try:
314
+ other = torch.broadcast_to(other, self.shape)
315
+ other = sparse_batch_broadcast(self, other)
316
+ except:
317
+ pass
318
+ if isinstance(other, SparseTensor):
319
+ other = other.feats
320
+ new_feats = op(self.feats, other)
321
+ new_tensor = self.replace(new_feats)
322
+ if isinstance(other, SparseTensor):
323
+ new_tensor._spatial_cache = self.__merge_sparse_cache(other)
324
+ return new_tensor
325
+
326
    # Arithmetic operators: all delegate to __elemwise__; the r* variants swap
    # operand order via a lambda so `scalar op tensor` works as well.
    def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.sub)

    def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))

    def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.div)

    def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.div(y, x))
349
+
350
    def __getitem__(self, idx):
        """Index along the batch dimension (int, slice, bool mask, or index tensor).

        Returns a new SparseTensor containing the selected batch items, with
        batch indices renumbered 0..len(idx)-1.
        """
        # Normalize every supported index form into an iterable of batch ids.
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        coords = []
        feats = []
        # `layout[old_idx]` is the slice of rows belonging to batch item old_idx;
        # clone coords so the batch column can be rewritten in place.
        for new_idx, old_idx in enumerate(idx):
            coords.append(self.coords[self.layout[old_idx]].clone())
            coords[-1][:, 0] = new_idx
            feats.append(self.feats[self.layout[old_idx]])
        coords = torch.cat(coords, dim=0).contiguous()
        feats = torch.cat(feats, dim=0).contiguous()
        return SparseTensor(feats=feats, coords=coords)
375
+
376
+ def register_spatial_cache(self, key, value) -> None:
377
+ """
378
+ Register a spatial cache.
379
+ The spatial cache can be any thing you want to cache.
380
+ The registery and retrieval of the cache is based on current scale.
381
+ """
382
+ scale_key = str(self._scale)
383
+ if scale_key not in self._spatial_cache:
384
+ self._spatial_cache[scale_key] = {}
385
+ self._spatial_cache[scale_key][key] = value
386
+
387
+ def get_spatial_cache(self, key=None):
388
+ """
389
+ Get a spatial cache.
390
+ """
391
+ scale_key = str(self._scale)
392
+ cur_scale_cache = self._spatial_cache.get(scale_key, {})
393
+ if key is None:
394
+ return cur_scale_cache
395
+ return cur_scale_cache.get(key, None)
396
+
397
+
398
def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a per-batch tensor to every point of a sparse tensor.

    Args:
        input (SparseTensor): sparse tensor supplying the batch layout.
        other (torch.Tensor): per-batch values, indexed by batch id ([B, ...]).

    Returns:
        torch.Tensor: per-point tensor shaped like input.feats, where every
        point belonging to batch k holds other[k].
    """
    # (Original unpacked input.coords/input.feats into locals; coords was unused.)
    broadcasted = torch.zeros_like(input.feats)
    for k in range(input.shape[0]):
        broadcasted[input.layout[k]] = other[k]
    return broadcasted
412
+
413
+
414
def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
    """
    Broadcast a per-batch tensor across a sparse tensor's points, then apply
    a binary op to the features.

    Args:
        input (SparseTensor): sparse tensor whose features are the left operand.
        other (torch.Tensor): per-batch tensor ([B, ...]) broadcast to each point.
        op (callable): applied as op(input.feats, broadcasted). Defaults to torch.add.
    """
    return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
424
+
425
+
426
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate a list of sparse tensors.

    Args:
        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
        dim (int): 0 concatenates along the batch dimension (batch indices are
            re-offset); any other value concatenates the feature tensors along
            that dim, keeping the first input's coordinates.
    """
    if dim == 0:
        start = 0
        coords = []
        # Shift each input's batch-index column so batches stay distinct.
        for input in inputs:
            coords.append(input.coords.clone())
            coords[-1][:, 0] += start
            start += input.shape[0]
        coords = torch.cat(coords, dim=0)
        feats = torch.cat([input.feats for input in inputs], dim=0)
        output = SparseTensor(
            coords=coords,
            feats=feats,
        )
    else:
        feats = torch.cat([input.feats for input in inputs], dim=dim)
        output = inputs[0].replace(feats)

    return output
451
+
452
+
453
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Unbind a sparse tensor along a dimension.

    Args:
        input (SparseTensor): Sparse tensor to unbind.
        dim (int): 0 splits into one SparseTensor per batch item (via
            __getitem__); any other value unbinds the feature tensor along
            that dim, keeping the shared coordinates.
    """
    if dim == 0:
        return [input[i] for i in range(input.shape[0])]
    else:
        feats = input.feats.unbind(dim)
        return [input.replace(f) for f in feats]
anigen/modules/sparse/conv/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from .. import BACKEND


# Preferred spconv convolution algorithm; may be overridden via the
# SPCONV_ALGO environment variable.
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'

def __from_env():
    """Override SPCONV_ALGO from the environment, accepting only known values."""
    import os
    
    global SPCONV_ALGO
    env_spconv_algo = os.environ.get('SPCONV_ALGO')
    if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
        SPCONV_ALGO = env_spconv_algo
    print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")


__from_env()

# Re-export the conv implementation matching the active sparse backend.
if BACKEND == 'torchsparse':
    from .conv_torchsparse import *
elif BACKEND == 'spconv':
    from .conv_spconv import *
anigen/modules/sparse/conv/conv_spconv.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .. import SparseTensor
4
+ from .. import DEBUG
5
+ from . import SPCONV_ALGO
6
+
7
class SparseConv3d(nn.Module):
    """
    3D sparse convolution on a SparseTensor (spconv backend).

    Uses a submanifold convolution (SubMConv3d) when stride == 1 and no
    padding is given (active coordinates unchanged); otherwise a regular
    sparse convolution that changes the coordinate set.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
        super(SparseConv3d, self).__init__()
        # Bug fix: the lazy import must bind the MODULE global, not a
        # function local — forward() references `spconv.SparseConvTensor`,
        # and a second construction would otherwise hit NameError at
        # `spconv.ConvAlgo` since the guard never saw the local binding.
        global spconv
        if 'spconv' not in globals():
            import spconv.pytorch as spconv
        algo = None
        if SPCONV_ALGO == 'native':
            algo = spconv.ConvAlgo.Native
        elif SPCONV_ALGO == 'implicit_gemm':
            algo = spconv.ConvAlgo.MaskImplicitGemm
        if stride == 1 and (padding is None):
            self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
        else:
            self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
        self.padding = padding

    def forward(self, x: SparseTensor) -> SparseTensor:
        spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
        new_data = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        # Per-batch layout survives only when the coordinate set is unchanged.
        new_layout = None if spatial_changed else x.layout

        if spatial_changed and (x.shape[0] != 1):
            # spconv with non-1 stride breaks batch-contiguity of the output;
            # re-sort rows by batch index and remember how to undo it.
            fwd = new_data.indices[:, 0].argsort()
            bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
            sorted_feats = new_data.features[fwd]
            sorted_coords = new_data.indices[fwd]
            unsorted_data = new_data
            new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size)  # type: ignore

        out = SparseTensor(
            new_data, shape=torch.Size(new_shape), layout=new_layout,
            scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )

        if spatial_changed and (x.shape[0] != 1):
            # Cache the unsorted tensor + inverse permutation so a paired
            # SparseInverseConv3d can restore spconv's original ordering.
            out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
            out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)

        return out
50
+
51
+
52
class SparseInverseConv3d(nn.Module):
    """
    Inverse (transposed) 3D sparse convolution (spconv backend).

    Must be paired via `indice_key` with a strided SparseConv3d of the same
    stride; that conv's spatial cache supplies the pre-sort ordering.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseInverseConv3d, self).__init__()
        # NOTE(review): this guard never becomes true (the import binds a
        # function-local name), so the import re-runs per construction;
        # harmless because sys.modules caches the module.
        if 'spconv' not in globals():
            import spconv.pytorch as spconv
        # spconv recovers geometry from `indice_key`; `stride`/`dilation` are
        # kept only for cache keys and scale bookkeeping below.
        self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)

    def forward(self, x: SparseTensor) -> SparseTensor:
        spatial_changed = any(s != 1 for s in self.stride)
        if spatial_changed:
            # recover the original spconv order
            data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
            bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
            data = data.replace_feature(x.feats[bwd])
            if DEBUG:
                assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
        else:
            data = x.data

        new_data = self.conv(data)
        new_shape = [x.shape[0], self.conv.out_channels]
        new_layout = None if spatial_changed else x.layout
        out = SparseTensor(
            new_data, shape=torch.Size(new_shape), layout=new_layout,
            # Inverse conv upsamples, so the scale shrinks back by the stride.
            scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
            spatial_cache=x._spatial_cache,
        )
        return out
anigen/modules/sparse/conv/conv_torchsparse.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .. import SparseTensor
4
+
5
+
6
class SparseConv3d(nn.Module):
    """3D sparse convolution on a SparseTensor (torchsparse backend)."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseConv3d, self).__init__()
        # NOTE(review): guard never becomes true (local import), so the import
        # re-runs per construction; harmless via sys.modules caching.
        # `indice_key` is accepted for interface parity with the spconv
        # backend but unused here.
        if 'torchsparse' not in globals():
            import torchsparse
        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)

    def forward(self, x: SparseTensor) -> SparseTensor:
        out = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        # Per-batch layout survives only when stride is 1 on every axis.
        out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
        out._spatial_cache = x._spatial_cache
        out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
        return out
20
+
21
+
22
class SparseInverseConv3d(nn.Module):
    """Transposed 3D sparse convolution (torchsparse backend)."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
        super(SparseInverseConv3d, self).__init__()
        # NOTE(review): guard never becomes true (local import); see
        # SparseConv3d. `indice_key` is unused on this backend.
        if 'torchsparse' not in globals():
            import torchsparse
        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)

    def forward(self, x: SparseTensor) -> SparseTensor:
        out = self.conv(x.data)
        new_shape = [x.shape[0], self.conv.out_channels]
        out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
        out._spatial_cache = x._spatial_cache
        # Transposed conv upsamples: scale shrinks back by the stride.
        out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
        return out
36
+
37
+
38
+
anigen/modules/sparse/linear.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import SparseTensor
4
+
5
+ __all__ = [
6
+ 'SparseLinear'
7
+ ]
8
+
9
+
10
class SparseLinear(nn.Linear):
    """Linear layer applied point-wise to the features of a SparseTensor."""

    def __init__(self, in_features, out_features, bias=True):
        super(SparseLinear, self).__init__(in_features, out_features, bias)

    def forward(self, input: SparseTensor) -> SparseTensor:
        projected = super().forward(input.feats)
        return input.replace(projected)
anigen/modules/sparse/nonlinearity.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import SparseTensor
4
+
5
+ __all__ = [
6
+ 'SparseReLU',
7
+ 'SparseSiLU',
8
+ 'SparseGELU',
9
+ 'SparseActivation'
10
+ ]
11
+
12
+
13
class SparseReLU(nn.ReLU):
    """ReLU applied point-wise to a SparseTensor's features."""
    def forward(self, input: SparseTensor) -> SparseTensor:
        return input.replace(super().forward(input.feats))


class SparseSiLU(nn.SiLU):
    """SiLU applied point-wise to a SparseTensor's features."""
    def forward(self, input: SparseTensor) -> SparseTensor:
        return input.replace(super().forward(input.feats))


class SparseGELU(nn.GELU):
    """GELU applied point-wise to a SparseTensor's features."""
    def forward(self, input: SparseTensor) -> SparseTensor:
        return input.replace(super().forward(input.feats))
26
+
27
+
28
class SparseActivation(nn.Module):
    """Wrap an arbitrary dense activation so it operates on SparseTensor feats."""

    def __init__(self, activation: nn.Module):
        super().__init__()
        self.activation = activation

    def forward(self, input: SparseTensor) -> SparseTensor:
        activated = self.activation(input.feats)
        return input.replace(activated)
35
+
anigen/modules/sparse/norm.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from . import SparseTensor
4
+ from . import DEBUG
5
+
6
+ __all__ = [
7
+ 'SparseGroupNorm',
8
+ 'SparseLayerNorm',
9
+ 'SparseGroupNorm32',
10
+ 'SparseLayerNorm32',
11
+ ]
12
+
13
+
14
class SparseGroupNorm(nn.GroupNorm):
    """GroupNorm over a SparseTensor, normalizing each batch item separately.

    Each batch item's points are gathered via `layout`, reshaped to
    (1, C, N_points), group-normalized, and scattered back.
    """
    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)

    def forward(self, input: SparseTensor) -> SparseTensor:
        nfeats = torch.zeros_like(input.feats)
        for k in range(input.shape[0]):
            if DEBUG:
                assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
            # (N_k, C) -> (1, C, N_k) so GroupNorm sees channels in dim 1.
            bfeats = input.feats[input.layout[k]]
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)
29
+
30
+
31
class SparseLayerNorm(nn.LayerNorm):
    """LayerNorm over a SparseTensor, applied per batch item.

    NOTE(review): the (1, C, N_points) permute mirrors SparseGroupNorm, but
    nn.LayerNorm(normalized_shape=C) normalizes over the LAST dim — shape
    (1, C, N) would only be accepted when N == C. Looks suspicious; confirm
    this class is actually exercised / intended.
    """
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input: SparseTensor) -> SparseTensor:
        nfeats = torch.zeros_like(input.feats)
        for k in range(input.shape[0]):
            bfeats = input.feats[input.layout[k]]
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)
44
+
45
+
46
class SparseGroupNorm32(SparseGroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass,
    then casts the result back to the input dtype (mixed-precision safety).
    """
    def forward(self, x: SparseTensor) -> SparseTensor:
        return super().forward(x.float()).type(x.dtype)
52
+
53
class SparseLayerNorm32(SparseLayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass,
    then casts the result back to the input dtype (mixed-precision safety).
    """
    def forward(self, x: SparseTensor) -> SparseTensor:
        return super().forward(x.float()).type(x.dtype)
anigen/modules/sparse/spatial.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from . import SparseTensor
5
+
6
+ __all__ = [
7
+ 'SparseDownsample',
8
+ 'SparseUpsample',
9
+ 'SparseSubdivide'
10
+ ]
11
+
12
+
13
class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.
    Implemented as average pooling: points whose coordinates collapse to the
    same coarse voxel are merged, and caches are registered so a paired
    SparseUpsample can invert the mapping.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
        super(SparseDownsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'

        # Integer-divide spatial coords (column 0 is the batch index).
        coord = list(input.coords.unbind(dim=-1))
        for i, f in enumerate(factor):
            coord[i+1] = coord[i+1] // f

        # Pack (batch, x, y, z) into a single linear code for deduplication.
        MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
        code = sum([c * o for c, o in zip(coord, OFFSET)])
        code, idx = code.unique(return_inverse=True)

        # Average features of points mapping to the same coarse voxel.
        # NOTE(review): scatter_reduce's default include_self=True folds the
        # zero-initialized destination into the mean (divides by count+1) —
        # confirm this bias is intended before changing trained behavior.
        new_feats = torch.scatter_reduce(
            torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
            src=input.feats,
            reduce='mean'
        )
        # Unpack the linear code back into (batch, x, y, z) columns.
        new_coords = torch.stack(
            [code // OFFSET[0]] +
            [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
            dim=-1
        )
        out = SparseTensor(new_feats, new_coords, input.shape,)
        out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache

        # Caches consumed by SparseUpsample to restore the fine resolution.
        out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
        out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
        out.register_spatial_cache(f'upsample_{factor}_idx', idx)

        return out
57
+
58
+
59
class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.
    Implemented as nearest neighbor interpolation: relies on the caches
    registered by a paired SparseDownsample to restore the fine coordinates.
    """
    def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
        super(SparseUpsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'

        new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
        new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
        idx = input.get_spatial_cache(f'upsample_{factor}_idx')
        if any([x is None for x in [new_coords, new_layout, idx]]):
            raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
        # Each fine point copies the features of its coarse parent voxel.
        new_feats = input.feats[idx]
        out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
        out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache
        return out
83
+
84
class SparseSubdivide(nn.Module):
    """
    Subdivide each active voxel into its 2^DIM children (octree-style split
    in 3D). Child voxels inherit the parent's features unchanged.
    """
    def __init__(self):
        super(SparseSubdivide, self).__init__()

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        # upsample scale=2^DIM
        # Enumerate all 2^DIM corner offsets, with a zero batch-offset column.
        n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        factor = n_coords.shape[0]
        assert factor == 2 ** DIM
        # Double each parent coordinate, then add every corner offset.
        new_coords = input.coords.clone()
        new_coords[:, 1:] *= 2
        new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)

        new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
        out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
        # Bug fix: `_scale` is a tuple of per-axis ints everywhere else in
        # this package, so `input._scale * 2` would REPEAT the tuple rather
        # than double each entry; scale every axis instead.
        out._scale = tuple(s * 2 for s in input._scale)
        out._spatial_cache = input._spatial_cache
        return out
110
+
anigen/modules/sparse/transformer/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .blocks import *
2
+ from .modulated import *
3
+ from .anigen_modulated import *
anigen/modules/sparse/transformer/anigen_modulated.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..basic import SparseTensor
5
+ from ..attention import SparseMultiHeadAttention, SerializeMode
6
+ from ...norm import LayerNorm32
7
+ from .blocks import SparseFeedForwardNet
8
+
9
+
10
class AniGenModulatedSparseTransformerCrossBlock(nn.Module):
    """
    AniGen Sparse Transformer cross-attention block (MSA + MCA + FFN) with
    adaptive layer norm conditioning.

    Runs two parallel streams — shape (`x`) and skeleton (`x_skl`) — that
    cross-attend to each other, each also cross-attending to a shared
    `context`, with per-stream adaLN modulation derived from `mod`/`mod_skl`.
    """
    def __init__(
        self,
        channels: int,              # shape-stream width
        channels_skl: int,          # skeleton-stream width
        ctx_channels: int,          # shared context width
        num_heads: int,
        num_heads_skl: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,    # if True, mod/mod_skl arrive pre-projected

    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm1_skl = LayerNorm32(channels_skl, elementwise_affine=False, eps=1e-6)
        self.norm2_skl = LayerNorm32(channels_skl, elementwise_affine=True, eps=1e-6)
        self.norm3_skl = LayerNorm32(channels_skl, elementwise_affine=False, eps=1e-6)
        # Shape stream attends over the skeleton stream (and vice versa).
        self.attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=channels_skl,
            num_heads=num_heads,
            type="cross",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.attn_skl = SparseMultiHeadAttention(
            channels_skl,
            ctx_channels=channels,
            num_heads=num_heads_skl,
            type="cross",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Both streams cross-attend to the shared conditioning context.
        self.context_cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.context_cross_attn_skl = SparseMultiHeadAttention(
            channels_skl,
            ctx_channels=ctx_channels,
            num_heads=num_heads_skl,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        self.mlp_skl = SparseFeedForwardNet(
            channels_skl,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            # Per-block adaLN projections: 6 modulation params per stream
            # (shift/scale/gate for attention and MLP each).
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )
            self.adaLN_modulation_skl = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels_skl, 6 * channels_skl, bias=True)
            )

    def _forward(self, x: SparseTensor, x_skl: SparseTensor, mod: torch.Tensor, mod_skl: torch.Tensor, context: torch.Tensor) -> Tuple[SparseTensor, SparseTensor]:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
            shift_msa_skl, scale_msa_skl, gate_msa_skl, shift_mlp_skl, scale_mlp_skl, gate_mlp_skl = mod_skl.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
            shift_msa_skl, scale_msa_skl, gate_msa_skl, shift_mlp_skl, scale_mlp_skl, gate_mlp_skl = self.adaLN_modulation_skl(mod_skl).chunk(6, dim=1)
        # Input Norm
        h = x.replace(self.norm1(x.feats))
        h_skl = x_skl.replace(self.norm1_skl(x_skl.feats))
        # AdaLN (By Time Step)
        h = h * (1 + scale_msa) + shift_msa
        h_skl = h_skl * (1 + scale_msa_skl) + shift_msa_skl
        # Self Attn (Cross shape and skeleton)
        # NOTE(review): h_skl attends to the ALREADY-updated h (sequential,
        # not symmetric) — presumably intentional; confirm.
        h = self.attn(h, h_skl)
        h_skl = self.attn_skl(h_skl, h)
        # Gated Residual (By Time Step)
        h = h * gate_msa
        h_skl = h_skl * gate_msa_skl
        x = x + h
        x_skl = x_skl + h_skl
        # Context Cross Attention
        h = x.replace(self.norm2(x.feats))
        h_skl = x_skl.replace(self.norm2_skl(x_skl.feats))
        h = self.context_cross_attn(h, context)
        h_skl = self.context_cross_attn_skl(h_skl, context)
        x = x + h
        x_skl = x_skl + h_skl
        # Re-Centered
        h = x.replace(self.norm3(x.feats))
        h_skl = x_skl.replace(self.norm3_skl(x_skl.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h_skl = h_skl * (1 + scale_mlp_skl) + shift_mlp_skl
        # Output MLP
        h = self.mlp(h)
        h_skl = self.mlp_skl(h_skl)
        # Gated Residual (By Time Step)
        h = h * gate_mlp
        h_skl = h_skl * gate_mlp_skl
        x = x + h
        x_skl = x_skl + h_skl
        return x, x_skl

    def forward(self, x: SparseTensor, x_skl: SparseTensor, mod: torch.Tensor, mod_skl: torch.Tensor, context: torch.Tensor) -> Tuple[SparseTensor, SparseTensor]:
        # Optionally gradient-checkpoint the whole block to save memory.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, x_skl, mod, mod_skl, context, use_reentrant=False)
        else:
            return self._forward(x, x_skl, mod, mod_skl, context)
anigen/modules/sparse/transformer/blocks.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..basic import SparseTensor
5
+ from ..linear import SparseLinear
6
+ from ..nonlinearity import SparseGELU
7
+ from ..attention import SparseMultiHeadAttention, SerializeMode
8
+ from ...norm import LayerNorm32
9
+
10
+
11
class SparseFeedForwardNet(nn.Module):
    """Two-layer MLP on sparse features: expand by `mlp_ratio`, GELU, project back."""

    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            SparseLinear(channels, hidden),
            SparseGELU(approximate="tanh"),
            SparseLinear(hidden, channels),
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        return self.mlp(x)
22
+
23
+
24
class SparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN): pre-norm self-attention followed by
    a pre-norm feed-forward network, each with a residual connection.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor) -> SparseTensor:
        # Self-attention
        h = x.replace(self.norm1(x.feats))
        h = self.attn(h)
        x = x + h
        # Feed-forward network
        h = x.replace(self.norm2(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor) -> SparseTensor:
        # Optionally gradient-checkpoint the block to save activation memory.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)
81
+
82
+
83
class SparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN).

    Pre-norm residual layout over ``SparseTensor`` inputs: self-attention,
    cross-attention against ``context``, then a feed-forward network.

    Args:
        channels: Feature width of the sparse tokens.
        ctx_channels: Feature width of the cross-attention context.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-layer expansion factor of the FFN.
        attn_mode / attn_mode_cross: Attention variants for the self- and
            cross-attention layers respectively.
        window_size / shift_sequence / shift_window / serialize_mode:
            Mode-specific attention options.
        use_checkpoint: Recompute activations in backward to save memory.
        use_rope: Rotary position embeddings (self-attention only).
        qk_rms_norm / qk_rms_norm_cross: RMS-normalize Q/K in self-/cross-attention.
        qkv_bias: Bias on QKV projections.
        ln_affine: Learnable affine parameters on norm1/norm2/norm3.
        context_is_dual: If True, the context features are normalized by
            ``norm4`` before cross-attention.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "serialized", "windowed"] = "full",
        attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
        context_is_dual: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.context_is_dual = context_is_dual
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        if context_is_dual:
            # NOTE(review): norm4 is sized with `channels`, not `ctx_channels`;
            # this assumes a dual context has the same width as x — confirm.
            self.norm4 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode=attn_mode_cross,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor, context: SparseTensor):
        # Self-attention
        h = x.replace(self.norm1(x.feats))
        h = self.self_attn(h)
        x = x + h
        # Cross-attention
        h = x.replace(self.norm2(x.feats))
        if self.context_is_dual:
            # Normalize the context features before attending to them.
            context = context.replace(self.norm4(context.feats))
        h = self.cross_attn(h, context)
        x = x + h
        # Feed-forward network
        h = x.replace(self.norm3(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor, context: SparseTensor):
        # Optionally trade compute for memory with activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)
168
+
169
+
170
class SparseTransformerMultiContextCrossBlock(nn.Module):
    """
    Sparse Transformer block with one self-attention and several parallel
    cross-attention streams (MSA + per-modality MCA + FFN).

    One cross-attention layer (``cross_attn_{i}``) and one context LayerNorm
    (``ctx_norm_{i}``) are created per entry of ``ctx_channels``; each
    stream's output is summed into the residual stream.

    Args:
        channels: Feature width of the sparse tokens.
        ctx_channels: One feature width per context modality; its length
            fixes the number of cross-attention streams.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-layer expansion factor of the FFN.
        attn_mode / attn_mode_cross: Attention variants for self/cross layers.
        window_size / shift_sequence / shift_window / serialize_mode:
            Mode-specific attention options.
        use_checkpoint: Recompute activations in backward to save memory.
        use_rope: Rotary position embeddings (self-attention only).
        qk_rms_norm / qk_rms_norm_cross: RMS-normalize Q/K.
        qkv_bias: Bias on QKV projections.
        ln_affine: Affine params on norm1/norm3 (norm2 is always affine here).
        cross_attn_cache_suffix: Forwarded to each cross-attention layer as
            '<suffix>_modality_<i>' — presumably a key for attention caching;
            confirm against SparseMultiHeadAttention.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: List[int],
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "serialized", "windowed"] = "full",
        attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
        cross_attn_cache_suffix: str = '',
    ):
        super().__init__()
        self.context_num = len(ctx_channels)
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        if self.context_num > 0:
            # Pre-cross-attention norm; always affine regardless of ln_affine.
            self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)

        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # One cross-attention layer + context norm per modality, registered
        # as attributes cross_attn_0..N-1 / ctx_norm_0..N-1.
        for i in range(self.context_num):
            setattr(self, f'cross_attn_{i}', SparseMultiHeadAttention(
                channels,
                ctx_channels=ctx_channels[i],
                num_heads=num_heads,
                type="cross",
                attn_mode=attn_mode_cross,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                qkv_bias=qkv_bias,
                qk_rms_norm=qk_rms_norm_cross,
                cross_attn_cache_suffix=cross_attn_cache_suffix + f'_modality_{i}'
            ))
            setattr(self, f'ctx_norm_{i}', LayerNorm32(ctx_channels[i], elementwise_affine=True, eps=1e-6))

        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor, contexts: List[SparseTensor]) -> SparseTensor:
        # Self-attention
        h = x.replace(self.norm1(x.feats))
        h = self.self_attn(h)
        x = x + h
        # Cross-attention: the normalized query h_norm is computed once from
        # the pre-cross residual state and shared by every modality stream;
        # each stream's output is added to x independently.
        if self.context_num > 0 and len(contexts) > 0:
            h_norm = x.replace(self.norm2(x.feats))
            for i, context in enumerate(contexts):
                context = context.replace(getattr(self, f'ctx_norm_{i}')(context.feats))
                h = getattr(self, f'cross_attn_{i}')(h_norm, context)
                x = x + h
        # Feed-forward network
        h = x.replace(self.norm3(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor, contexts: List[SparseTensor]) -> SparseTensor:
        # Optionally trade compute for memory with activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, contexts, use_reentrant=False)
        else:
            return self._forward(x, contexts)
anigen/modules/sparse/transformer/modulated.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..basic import SparseTensor
5
+ from ..attention import SparseMultiHeadAttention, SerializeMode
6
+ from ...norm import LayerNorm32
7
+ from .blocks import SparseFeedForwardNet
8
+
9
+
10
class ModulatedSparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.

    DiT-style adaLN: the conditioning vector ``mod`` is projected to six
    per-batch vectors (shift/scale/gate for the attention stage and for the
    FFN stage). With ``share_mod=True``, ``mod`` is expected to already
    contain the six concatenated modulation vectors (projected by the caller).

    Args:
        channels: Feature width of the sparse tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-layer expansion factor of the FFN.
        attn_mode / window_size / shift_sequence / shift_window / serialize_mode:
            Sparse attention configuration, forwarded unchanged.
        use_checkpoint: Recompute activations in backward to save memory.
        use_rope: Rotary position embeddings inside attention.
        qk_rms_norm: RMS-normalize Q/K before attention.
        qkv_bias: Bias on the QKV projection.
        share_mod: If True, skip the per-block adaLN projection.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Non-affine norms: scale/shift come from the modulation instead.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Attention stage: modulated pre-norm, gated residual.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.attn(h)
        h = h * gate_msa
        x = x + h
        # FFN stage: modulated pre-norm, gated residual.
        h = x.replace(self.norm2(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        # Optionally trade compute for memory with activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
        else:
            return self._forward(x, mod)
79
+
80
+
81
class ModulatedSparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive
    layer norm conditioning.

    Self-attention and FFN stages are adaLN-modulated (shift/scale/gate);
    the cross-attention stage is not modulated and instead uses an affine
    LayerNorm (``norm2``). Cross-attention always uses "full" attention.

    Args:
        channels: Feature width of the sparse tokens.
        ctx_channels: Feature width of the cross-attention context.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-layer expansion factor of the FFN.
        attn_mode / window_size / shift_sequence / shift_window / serialize_mode:
            Self-attention configuration, forwarded unchanged.
        use_checkpoint: Recompute activations in backward to save memory.
        use_rope: Rotary position embeddings (self-attention only).
        qk_rms_norm / qk_rms_norm_cross: RMS-normalize Q/K in self-/cross-attention.
        qkv_bias: Bias on QKV projections.
        share_mod: If True, `mod` already contains the six modulation vectors.
        norm_for_context: If True, normalize the context (sparse or dense)
            before cross-attention.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
        norm_for_context: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        # norm2 feeds the unmodulated cross-attention stage, hence affine.
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm_for_context = norm_for_context
        if self.norm_for_context:
            self.context_norm = LayerNorm32(ctx_channels, elementwise_affine=True, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Self-attention: modulated pre-norm, gated residual.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.self_attn(h)
        h = h * gate_msa
        x = x + h
        # Cross-attention: plain (unmodulated) pre-norm residual.
        h = x.replace(self.norm2(x.feats))
        if self.norm_for_context:
            # Context may be sparse or dense; normalize its features either way.
            if isinstance(context, SparseTensor):
                context = context.replace(self.context_norm(context.feats))
            else:
                context = self.context_norm(context)
        h = self.cross_attn(h, context)
        x = x + h
        # FFN: modulated pre-norm, gated residual.
        h = x.replace(self.norm3(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
        # Optionally trade compute for memory with activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)
anigen/modules/spatial.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
    """
    3D pixel shuffle: move channel blocks into spatial resolution.

    Rearranges a (B, C, H, W, D) tensor into
    (B, C // scale_factor**3, H*s, W*s, D*s), the 3D analogue of
    ``torch.nn.PixelShuffle``.

    Args:
        x: Input tensor of shape (B, C, H, W, D).
        scale_factor: Per-axis spatial upscaling factor; C must be divisible
            by ``scale_factor**3``.

    Returns:
        The rearranged tensor of shape (B, C_, H*s, W*s, D*s).

    Raises:
        ValueError: If C is not divisible by ``scale_factor**3`` (previously
            this surfaced as an opaque reshape error).
    """
    B, C, H, W, D = x.shape
    if C % scale_factor**3 != 0:
        raise ValueError(
            f"Channel dim {C} must be divisible by scale_factor**3 ({scale_factor**3})"
        )
    C_ = C // scale_factor**3
    # Split channels into (C_, s, s, s), then interleave each s with its
    # corresponding spatial axis before flattening back.
    x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
    x = x.reshape(B, C_, H * scale_factor, W * scale_factor, D * scale_factor)
    return x
14
+
15
+
16
def patchify(x: torch.Tensor, patch_size: int):
    """
    Fold non-overlapping spatial patches into the channel dimension.

    Maps (N, C, *spatial) to (N, C * patch_size**DIM, *(spatial // patch_size)),
    where DIM is the number of spatial axes (inferred from x.dim()).

    Args:
        x (torch.Tensor): (N, C, *spatial) tensor
        patch_size (int): Patch size
    """
    n_spatial = x.dim() - 2
    for d in range(2, n_spatial + 2):
        assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"

    # Split every spatial axis into (num_blocks, patch_size).
    split_shape = []
    for d in range(2, n_spatial + 2):
        split_shape += [x.shape[d] // patch_size, patch_size]
    x = x.reshape(*x.shape[:2], *split_shape)
    # Move all intra-patch axes next to the channel axis, block axes last.
    patch_axes = [2 * i + 3 for i in range(n_spatial)]
    block_axes = [2 * i + 2 for i in range(n_spatial)]
    x = x.permute(0, 1, *patch_axes, *block_axes)
    # Merge the intra-patch axes into the channel dimension.
    return x.reshape(x.shape[0], x.shape[1] * patch_size ** n_spatial, *x.shape[-n_spatial:])
32
+
33
+
34
def unpatchify(x: torch.Tensor, patch_size: int):
    """
    Inverse of ``patchify``: unfold channel blocks back into spatial patches.

    Maps (N, C, *spatial) to (N, C // patch_size**DIM, *(spatial * patch_size)),
    where DIM is the number of spatial axes (inferred from x.dim()).

    Args:
        x (torch.Tensor): (N, C, *spatial) tensor
        patch_size (int): Patch size
    """
    n_spatial = x.dim() - 2
    assert x.shape[1] % (patch_size ** n_spatial) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** n_spatial}"

    # Pull the intra-patch axes back out of the channel dimension.
    x = x.reshape(
        x.shape[0],
        x.shape[1] // (patch_size ** n_spatial),
        *([patch_size] * n_spatial),
        *x.shape[-n_spatial:],
    )
    # Interleave (block, patch) axis pairs back into spatial order.
    interleave = []
    for i in range(n_spatial):
        interleave += [2 + n_spatial + i, 2 + i]
    x = x.permute(0, 1, *interleave)
    out_spatial = [x.shape[2 + 2 * i] * patch_size for i in range(n_spatial)]
    return x.reshape(x.shape[0], x.shape[1], *out_spatial)
anigen/modules/transformer/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .blocks import *
2
+ from .modulated import *
anigen/modules/transformer/blocks.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..attention import MultiHeadAttention
5
+ from ..norm import LayerNorm32
6
+
7
+
8
class AbsolutePositionEmbedder(nn.Module):
    """
    Sinusoidal embedding of spatial positions.

    Each of the ``in_channels`` coordinates is embedded with ``freq_dim``
    sin terms followed by ``freq_dim`` cos terms; if the concatenation is
    narrower than ``channels``, the output is zero-padded on the right.
    """
    def __init__(self, channels: int, in_channels: int = 3):
        super().__init__()
        self.channels = channels
        self.in_channels = in_channels
        self.freq_dim = channels // in_channels // 2
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        # Geometric frequency ladder, as in the classic transformer encoding.
        self.freqs = 1.0 / (10000 ** exponents)

    def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create sinusoidal position embeddings.

        Args:
            x: a 1-D Tensor of N indices

        Returns:
            an (N, D) Tensor of positional embeddings.
        """
        # freqs is a plain attribute (not a registered buffer), so migrate it
        # lazily to the input's device.
        self.freqs = self.freqs.to(x.device)
        angles = torch.outer(x, self.freqs)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (N, D) tensor of spatial positions
        """
        N, D = x.shape
        assert D == self.in_channels, "Input dimension must match number of input channels"
        embed = self._sin_cos_embedding(x.reshape(-1)).reshape(N, -1)
        pad = self.channels - embed.shape[1]
        if pad > 0:
            # Zero-pad up to the requested channel width.
            embed = torch.cat([embed, torch.zeros(N, pad, device=embed.device)], dim=-1)
        return embed
47
+
48
+
49
class FeedForwardNet(nn.Module):
    """
    Two-layer MLP with tanh-approximated GELU:
    channels -> channels * mlp_ratio -> out_channels (defaults to channels).
    """
    def __init__(self, channels: int, mlp_ratio: float = 4.0, out_channels: Optional[int] = None):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        final = channels if out_channels is None else out_channels
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden, final),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the MLP pointwise over the last dimension.
        return self.mlp(x)
62
+
63
+
64
class TransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN) in pre-norm residual layout.

    When ``out_channels`` differs from ``channels``, the residual stream of
    the FFN stage is projected through an extra 1x-ratio MLP (``res_mlp``)
    so the widths match before the final add.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        out_channels: Optional[int] = None,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[int] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.channels = channels
        self.out_channels = channels if out_channels is None else out_channels
        self.mlp = FeedForwardNet(
            self.channels,
            out_channels=self.out_channels,
            mlp_ratio=mlp_ratio,
        )
        if self.out_channels != self.channels:
            # Width-matching projection for the residual path.
            self.res_mlp = FeedForwardNet(
                self.channels,
                out_channels=self.out_channels,
                mlp_ratio=1.0,
            )

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention stage.
        x = x + self.attn(self.norm1(x))
        # FFN stage; project the residual if the output width differs.
        ffn_out = self.mlp(self.norm2(x))
        if self.channels != self.out_channels:
            x = self.res_mlp(x)
        return x + ffn_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Optionally trade compute for memory with activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)
127
+
128
+
129
class SkinTransformerCrossBlock(nn.Module):
    """
    Skinning block mixing point features with joint features.

    Point features ``x`` receive a skin-weighted aggregation of projected
    joint values (``skin @ V(LN(j))``) followed by an FFN; joint features
    ``j`` are updated in parallel with self-attention plus an FFN. Returns
    the updated ``(x, j)`` pair.

    Fix: the original class defined only ``_forward`` — calling the module
    raised (no ``forward``) and ``use_checkpoint`` was stored but never
    used. A ``forward`` dispatcher matching the sibling blocks is added.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        out_channels: Optional[int] = None,
        mlp_ratio: float = 4.0,
        use_checkpoint: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        # Value projection for the skin-weighted joint aggregation.
        self.to_v = nn.Linear(channels, channels, bias=qkv_bias)
        self.channels = channels
        self.out_channels = out_channels if out_channels is not None else channels
        self.mlp = FeedForwardNet(
            self.channels,
            out_channels=self.out_channels,
            mlp_ratio=mlp_ratio,
        )
        self.joint_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode="full",
            qkv_bias=qkv_bias,
        )
        self.joint_mlp = FeedForwardNet(
            self.channels,
            out_channels=self.out_channels,
            mlp_ratio=mlp_ratio,
        )
        if self.out_channels != self.channels:
            # Width-matching projections for both residual streams.
            self.res_mlp = FeedForwardNet(
                self.channels,
                out_channels=self.out_channels,
                mlp_ratio=1.0,
            )
            self.res_joint_mlp = FeedForwardNet(
                self.channels,
                out_channels=self.out_channels,
                mlp_ratio=1.0,
            )

    def _forward(self, x: torch.Tensor, j: torch.Tensor, skin: torch.Tensor) -> torch.Tensor:
        # Aggregate joint values into points using the skinning weights.
        v = self.to_v(self.norm1(j))
        h = skin @ v
        x = x + h
        h = self.norm2(x)
        h = self.mlp(h)
        if self.out_channels != self.channels:
            x = self.res_mlp(x)
        x = x + h

        # Update joint features: self-attention + FFN residual stream.
        h_j = self.norm3(j)
        h_j = self.joint_attn(h_j)
        h_j = j + h_j
        h_j = self.joint_mlp(h_j)
        if self.out_channels != self.channels:
            j = self.res_joint_mlp(j)
        j = j + h_j
        return x, j

    def forward(self, x: torch.Tensor, j: torch.Tensor, skin: torch.Tensor) -> torch.Tensor:
        # Optionally trade compute for memory with activation checkpointing,
        # matching the dispatch pattern used by the sibling blocks.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, j, skin, use_reentrant=False)
        return self._forward(x, j, skin)
198
+
199
+
200
class TransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN).

    Pre-norm residual layout. With ``no_self=True`` the self-attention stage
    is disabled: ``norm1`` becomes Identity and ``self_attn`` is replaced by
    a constant-zero callable, making the first residual add a no-op.

    Args:
        channels: Feature width of the query tokens.
        ctx_channels: Feature width of the cross-attention context.
        num_heads: Number of attention heads.
        out_channels: Output width; if it differs from ``channels`` the FFN
            residual is projected through ``res_mlp``.
        mlp_ratio: Hidden-layer expansion factor of the FFN.
        attn_mode / window_size / shift_window: Self-attention configuration.
        use_checkpoint: Recompute activations in backward to save memory.
        use_rope: Rotary position embeddings (self-attention only).
        qk_rms_norm / qk_rms_norm_cross: RMS-normalize Q/K.
        qkv_bias: Bias on QKV projections.
        ln_affine: Affine parameters on the LayerNorms.
        x_is_query: Forwarded to the cross-attention layer — presumably
            controls which side provides the query; confirm in MultiHeadAttention.
        no_self: Skip the self-attention stage entirely.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        out_channels: Optional[int] = None,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
        x_is_query: bool = False,
        no_self: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) if not no_self else nn.Identity()
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        if no_self:
            # Constant-zero stand-in (plain attribute, not a submodule; it
            # has no parameters), so `x = x + self.self_attn(...)` is a no-op.
            self.self_attn = lambda x: 0
        else:
            self.self_attn = MultiHeadAttention(
                channels,
                num_heads=num_heads,
                type="self",
                attn_mode=attn_mode,
                window_size=window_size,
                shift_window=shift_window,
                qkv_bias=qkv_bias,
                use_rope=use_rope,
                qk_rms_norm=qk_rms_norm,
            )
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
            x_is_query=x_is_query,
        )
        self.channels = channels
        self.out_channels = out_channels if out_channels is not None else channels
        self.mlp = FeedForwardNet(
            channels,
            out_channels=self.out_channels,
            mlp_ratio=mlp_ratio,
        )
        if self.out_channels != self.channels:
            # Width-matching projection for the FFN residual path.
            self.res_mlp = FeedForwardNet(
                self.channels,
                out_channels=self.out_channels,
                mlp_ratio=1.0,
            )

    def _forward(self, x: torch.Tensor, context: torch.Tensor):
        # Self-attention (zero contribution when no_self=True)
        h = self.norm1(x)
        h = self.self_attn(h)
        x = x + h
        # Cross-attention
        h = self.norm2(x)
        h = self.cross_attn(h, context)
        x = x + h
        # Feed-forward network; project residual if widths differ.
        h = self.norm3(x)
        h = self.mlp(h)
        if self.out_channels != self.channels:
            x = self.res_mlp(x)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, context: torch.Tensor):
        # Optionally trade compute for memory with activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)
285
+
anigen/modules/transformer/modulated.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import torch
3
+ import torch.nn as nn
4
+ from ..attention import MultiHeadAttention
5
+ from ..norm import LayerNorm32
6
+ from .blocks import FeedForwardNet
7
+
8
+
9
class ModulatedTransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN) with adaptive layer norm conditioning.

    DiT-style adaLN: the conditioning vector ``mod`` is projected to six
    per-sample vectors (shift/scale/gate for each of the two stages). With
    ``share_mod=True``, ``mod`` is expected to already hold the six
    concatenated modulation vectors (projected by the caller).
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Non-affine norms: scale/shift come from the modulation instead.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        # Either `mod` already carries the six modulation vectors, or we
        # project it with this block's own adaLN head.
        mod_params = mod if self.share_mod else self.adaLN_modulation(mod)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod_params.chunk(6, dim=1)
        # Attention stage: modulated pre-norm, gated residual.
        attn_in = self.norm1(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        x = x + self.attn(attn_in) * gate_msa.unsqueeze(1)
        # FFN stage: modulated pre-norm, gated residual.
        mlp_in = self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        x = x + self.mlp(mlp_in) * gate_mlp.unsqueeze(1)
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        # Optionally trade compute for memory with activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
        return self._forward(x, mod)
74
+
75
+
76
class ModulatedTransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer
    norm conditioning.

    Self-attention and FFN stages are adaLN-modulated (shift/scale/gate);
    the cross-attention stage is unmodulated and uses an affine LayerNorm
    (``norm2``). Cross-attention always uses "full" attention.

    Args:
        channels: Feature width of the query tokens.
        ctx_channels: Feature width of the cross-attention context.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-layer expansion factor of the FFN.
        attn_mode / window_size / shift_window: Self-attention configuration.
        use_checkpoint: Recompute activations in backward to save memory.
        use_rope: Rotary position embeddings (self-attention only).
        qk_rms_norm / qk_rms_norm_cross: RMS-normalize Q/K.
        qkv_bias: Bias on QKV projections.
        share_mod: If True, `mod` already contains the six modulation vectors.
        use_lora_self / lora_rank_self / use_lora_cross / lora_rank_cross /
        lora_lr_rate: LoRA options forwarded to the attention layers —
            semantics defined in MultiHeadAttention; confirm there.
        use_context_norm: If True, normalize the context before cross-attention.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,

        use_lora_self: bool = False,
        lora_rank_self: int = 4,
        use_lora_cross: bool = False,
        lora_rank_cross: int = 4,
        lora_lr_rate: float = 1.0,
        use_context_norm: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        # norm2 feeds the unmodulated cross-attention stage, hence affine.
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.use_context_norm = use_context_norm
        if self.use_context_norm:
            self.context_norm = LayerNorm32(ctx_channels, elementwise_affine=True, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
            use_lora=use_lora_self,
            lora_rank=lora_rank_self,
            lora_lr_rate=lora_lr_rate,
        )
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
            use_lora=use_lora_cross,
            lora_rank=lora_rank_cross,
            lora_lr_rate=lora_lr_rate,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Self-attention: modulated pre-norm, gated residual.
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.self_attn(h)
        h = h * gate_msa.unsqueeze(1)
        x = x + h
        # Cross-attention: plain (unmodulated) pre-norm residual.
        h = self.norm2(x)
        if self.use_context_norm:
            context = self.context_norm(context)
        h = self.cross_attn(h, context)
        x = x + h
        # FFN: modulated pre-norm, gated residual.
        h = self.norm3(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        h = h * gate_mlp.unsqueeze(1)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
        # Optionally trade compute for memory with activation checkpointing.
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)
175
+
anigen/modules/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from ..modules import sparse as sp
3
+
4
# Module types whose parameters are switched between fp16 and fp32 by
# convert_module_to_f16 / convert_module_to_f32 in this file.
FP16_MODULES = (
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
    sp.SparseConv3d,
    sp.SparseInverseConv3d,
    sp.SparseLinear,
)
16
+
17
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    Intended for use with ``nn.Module.apply``; only modules whose type is
    listed in FP16_MODULES are touched.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.half()
24
+
25
+
26
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().

    Intended for use with ``nn.Module.apply``; only modules whose type is
    listed in FP16_MODULES are touched.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.float()
33
+
34
+
35
def zero_module(module):
    """
    Zero out every parameter of *module* in place and return the module.
    """
    for param in module.parameters():
        # detach() so the in-place write bypasses autograd tracking.
        param.detach().zero_()
    return module
42
+
43
+
44
def scale_module(module, scale):
    """
    Multiply every parameter of *module* by *scale* in place and return it.
    """
    for param in module.parameters():
        # detach() so the in-place write bypasses autograd tracking.
        param.detach().mul_(scale)
    return module
51
+
52
+
53
def modulate(x, shift, scale):
    """
    AdaLN modulation: scale-and-shift *x* with per-sample parameters
    broadcast over the sequence dimension (dim 1).
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b