Commit ·
6b92ff7
0
Parent(s):
Initial commit: AniGen - Animatable 3D Generation
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +58 -0
- Dockerfile +63 -0
- README.md +231 -0
- THIRD_PARTY_LICENSES.md +30 -0
- anigen/__init__.py +6 -0
- anigen/datasets/__init__.py +32 -0
- anigen/datasets/anigen_sparse_feat2skeleton.py +290 -0
- anigen/datasets/anigen_sparse_structure.py +124 -0
- anigen/datasets/anigen_sparse_structure_latent.py +238 -0
- anigen/datasets/anigen_structured_latent.py +327 -0
- anigen/datasets/components.py +143 -0
- anigen/models/__init__.py +67 -0
- anigen/models/anigen_sparse_structure_flow.py +487 -0
- anigen/models/anigen_sparse_structure_vae.py +729 -0
- anigen/models/anigen_structured_latent_flow.py +553 -0
- anigen/models/sparse_elastic_mixin.py +24 -0
- anigen/models/structured_latent_vae/__init__.py +3 -0
- anigen/models/structured_latent_vae/anigen_base.py +256 -0
- anigen/models/structured_latent_vae/anigen_decoder.py +834 -0
- anigen/models/structured_latent_vae/anigen_encoder.py +318 -0
- anigen/models/structured_latent_vae/base.py +117 -0
- anigen/models/structured_latent_vae/skin_models.py +252 -0
- anigen/modules/attention/__init__.py +36 -0
- anigen/modules/attention/full_attn.py +140 -0
- anigen/modules/attention/modules.py +161 -0
- anigen/modules/norm.py +25 -0
- anigen/modules/sparse/__init__.py +102 -0
- anigen/modules/sparse/attention/__init__.py +5 -0
- anigen/modules/sparse/attention/full_attn.py +215 -0
- anigen/modules/sparse/attention/modules.py +151 -0
- anigen/modules/sparse/attention/serialized_attn.py +193 -0
- anigen/modules/sparse/attention/windowed_attn.py +135 -0
- anigen/modules/sparse/attention/windowed_attn_cross.py +131 -0
- anigen/modules/sparse/basic.py +465 -0
- anigen/modules/sparse/conv/__init__.py +21 -0
- anigen/modules/sparse/conv/conv_spconv.py +80 -0
- anigen/modules/sparse/conv/conv_torchsparse.py +38 -0
- anigen/modules/sparse/linear.py +15 -0
- anigen/modules/sparse/nonlinearity.py +35 -0
- anigen/modules/sparse/norm.py +58 -0
- anigen/modules/sparse/spatial.py +110 -0
- anigen/modules/sparse/transformer/__init__.py +3 -0
- anigen/modules/sparse/transformer/anigen_modulated.py +155 -0
- anigen/modules/sparse/transformer/blocks.py +259 -0
- anigen/modules/sparse/transformer/modulated.py +174 -0
- anigen/modules/spatial.py +48 -0
- anigen/modules/transformer/__init__.py +2 -0
- anigen/modules/transformer/blocks.py +285 -0
- anigen/modules/transformer/modulated.py +175 -0
- anigen/modules/utils.py +54 -0
.gitattributes
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
anigen/representations/mesh/flexicubes/images/block_init.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
anigen/representations/mesh/flexicubes/images/teaser_top.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/cond_images/dog.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/cond_images/lamp.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/cond_images/machine_arm.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/cond_images/owl.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
assets/cond_images/trex.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
assets/cond_images/whale.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
assets/images/teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
assets/cond_images/machine_dog.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
assets/cond_images/spongebob.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
assets/gifs/eagle.gif filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
assets/gifs/evo.gif filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
assets/gifs/horse.gif filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
assets/gifs/iron_boy.gif filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
assets/gifs/machine_arm.gif filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
assets/gifs/machine_dog.gif filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
assets/gifs/mairo.gif filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
assets/gifs/money_tree.gif filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
assets/cond_images/brickbob.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
assets/cond_images/bruno_star.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
assets/cond_images/evo.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
assets/cond_images/iron_boy.png filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
|
| 2 |
+
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 6 |
+
python3.10 python3-pip python3.10-dev \
|
| 7 |
+
git git-lfs ffmpeg libsm6 libxext6 libgl1 libegl1 \
|
| 8 |
+
build-essential ninja-build cmake rsync \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/* \
|
| 10 |
+
&& git lfs install
|
| 11 |
+
|
| 12 |
+
RUN useradd -m -u 1000 user
|
| 13 |
+
USER user
|
| 14 |
+
|
| 15 |
+
ENV HOME=/home/user \
|
| 16 |
+
PATH=/home/user/.local/bin:$PATH \
|
| 17 |
+
PYTHONUNBUFFERED=1 \
|
| 18 |
+
PIP_NO_CACHE_DIR=1 \
|
| 19 |
+
HF_HOME=/home/user/.cache/huggingface \
|
| 20 |
+
TORCH_HOME=/home/user/.cache/torch \
|
| 21 |
+
ATTN_BACKEND=xformers \
|
| 22 |
+
SPARSE_ATTN_BACKEND=xformers \
|
| 23 |
+
TORCH_CUDA_ARCH_LIST="7.5;8.6;8.9"
|
| 24 |
+
|
| 25 |
+
WORKDIR $HOME/app
|
| 26 |
+
|
| 27 |
+
COPY --chown=user:user . $HOME/app
|
| 28 |
+
|
| 29 |
+
RUN python3.10 -m pip install --upgrade pip setuptools wheel
|
| 30 |
+
|
| 31 |
+
RUN python3.10 -m pip install \
|
| 32 |
+
torch==2.4.0 torchvision==0.19.0 \
|
| 33 |
+
--index-url https://download.pytorch.org/whl/cu121
|
| 34 |
+
|
| 35 |
+
RUN python3.10 -m pip install \
|
| 36 |
+
pillow imageio imageio-ffmpeg tqdm easydict scipy ninja psutil safetensors \
|
| 37 |
+
scikit-learn opencv-python-headless rembg onnxruntime \
|
| 38 |
+
trimesh xatlas pyvista pymeshfix igraph pygltflib geffnet \
|
| 39 |
+
transformers \
|
| 40 |
+
gradio==4.44.1 gradio_litmodel3d==0.0.1 "huggingface_hub<0.25"
|
| 41 |
+
|
| 42 |
+
RUN python3.10 -m pip install \
|
| 43 |
+
git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
|
| 44 |
+
|
| 45 |
+
RUN python3.10 -m pip install \
|
| 46 |
+
"pytorch3d==0.7.8" \
|
| 47 |
+
--find-links https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html
|
| 48 |
+
|
| 49 |
+
RUN python3.10 -m pip install \
|
| 50 |
+
xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121
|
| 51 |
+
|
| 52 |
+
RUN python3.10 -m pip install spconv-cu121
|
| 53 |
+
|
| 54 |
+
RUN python3.10 -m pip install \
|
| 55 |
+
kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.4.0_cu121.html
|
| 56 |
+
|
| 57 |
+
RUN python3.10 -m pip install \
|
| 58 |
+
"git+https://github.com/NVlabs/nvdiffrast.git" --no-build-isolation
|
| 59 |
+
|
| 60 |
+
EXPOSE 7860
|
| 61 |
+
|
| 62 |
+
CMD ["python3.10", "app.py"]
|
| 63 |
+
|
README.md
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: AniGen
|
| 3 |
+
sdk: gradio
|
| 4 |
+
sdk_version: 4.44.1
|
| 5 |
+
python_version: 3.10.13
|
| 6 |
+
startup_duration_timeout: 2h
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
<h1 align="center">AniGen: Unified S<sup>3</sup> Fields for Animatable 3D Asset Generation</h1>
|
| 10 |
+
<p align="center"><a href="https://arxiv.org/pdf/2604.08746"><img src='https://img.shields.io/badge/arXiv-Paper-red?logo=arxiv&logoColor=white' alt='arXiv'></a>
|
| 11 |
+
<a href='https://yihua7.github.io/AniGen_web/'><img src='https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white' alt='Project Page'></a>
|
| 12 |
+
<a href='https://huggingface.co/spaces/VAST-AI/AniGen'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Live_Demo-blue'></a>
|
| 13 |
+
<a href="https://www.tripo3d.ai"><img src="https://img.shields.io/badge/Tripo-AI_3D_Workspace-orange" alt="Tripo"></a>
|
| 14 |
+
</p>
|
| 15 |
+
<p align="center"><img src="assets/images/teaser.png" width="100%"></p>
|
| 16 |
+
|
| 17 |
+
<span style="font-size: 16px; font-weight: 600;">A</span><span style="font-size: 12px; font-weight: 700;">niGen</span> is a unified framework that directly generates animate-ready 3D assets conditioned on a single image. Our key insight is to represent shape, skeleton, and skinning as mutually consistent *$S^3$ Fields* (Shape, Skeleton, Skin) defined over a shared spatial domain. To enable the robust learning of these fields, we introduce two technical innovations: (i) a *confidence-decaying skeleton field* that explicitly handles the geometric ambiguity of bone prediction at Voronoi boundaries, and (ii) a *dual skin feature field* that decouples skinning weights from specific joint counts, allowing a fixed-architecture network to predict rigs of arbitrary complexity. Built upon a two-stage flow-matching pipeline, <span style="font-size: 16px; font-weight: 600;">A</span><span style="font-size: 12px; font-weight: 700;">niGen</span> first synthesizes a sparse structural scaffold and then generates dense geometry and articulation in a structured latent space. Extensive experiments demonstrate that <span style="font-size: 16px; font-weight: 600;">A</span><span style="font-size: 12px; font-weight: 700;">niGen</span> substantially outperforms state-of-the-art sequential baselines in rig validity and animation quality, generalizing effectively to in-the-wild images across diverse categories including animals, humanoids, and machinery.
|
| 18 |
+
|
| 19 |
+
<!-- Overview -->
|
| 20 |
+
## 🔮 Overview
|
| 21 |
+
|
| 22 |
+
AniGen takes a **single image** as input and automatically produces a fully rigged, animate-ready 3D asset, complete with a coherent mesh, an articulated skeleton, and smooth skinning weights. The generated assets can be directly imported into standard 3D pipelines and driven by off-the-shelf motion data, enabling immediate deployment across a wide spectrum of downstream applications, including **embodied AI** agent construction, **physics-based simulation**, **character animation**, **dynamic scene creation**, and **articulated object manipulation**.
|
| 23 |
+
|
| 24 |
+
<table width="100%">
|
| 25 |
+
<tr>
|
| 26 |
+
<td width="25%" align="center"><img src="assets/gifs/machine_arm.gif" width="100%"><br><b>Machine Arm</b></td>
|
| 27 |
+
<td width="25%" align="center"><img src="assets/gifs/machine_dog.gif" width="100%"><br><b>Machine Dog</b></td>
|
| 28 |
+
<td width="25%" align="center"><img src="assets/gifs/money_tree.gif" width="100%"><br><b>Money Tree</b></td>
|
| 29 |
+
<td width="25%" align="center"><img src="assets/gifs/iron_boy.gif" width="100%"><br><b>Iron Boy</b></td>
|
| 30 |
+
</tr>
|
| 31 |
+
<tr>
|
| 32 |
+
<td width="25%" align="center"><img src="assets/gifs/mairo.gif" width="100%"><br><b>Mairo</b></td>
|
| 33 |
+
<td width="25%" align="center"><img src="assets/gifs/evo.gif" width="100%"><br><b>Evo</b></td>
|
| 34 |
+
<td width="25%" align="center"><img src="assets/gifs/horse.gif" width="100%"><br><b>Horse</b></td>
|
| 35 |
+
<td width="25%" align="center"><img src="assets/gifs/eagle.gif" width="100%"><br><b>Eagle</b></td>
|
| 36 |
+
</tr>
|
| 37 |
+
</table>
|
| 38 |
+
|
| 39 |
+
<!-- Installation -->
|
| 40 |
+
## 📦 Installation
|
| 41 |
+
|
| 42 |
+
### Prerequisites
|
| 43 |
+
- **System**: The code is currently tested only on **Linux**.
|
| 44 |
+
- **Hardware**: An NVIDIA GPU with at least 18GB of memory is necessary. The code has been verified on NVIDIA A800 and RTX3090 GPUs.
|
| 45 |
+
- **Software**:
|
| 46 |
+
- The [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) is needed to compile certain submodules. The code has been tested with CUDA versions 11.8 and 12.2.
|
| 47 |
+
- [Conda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) is recommended for managing dependencies.
|
| 48 |
+
- Python version 3.8 or higher is required.
|
| 49 |
+
|
| 50 |
+
### Installation Steps
|
| 51 |
+
1. Clone the repo:
|
| 52 |
+
```sh
|
| 53 |
+
git clone --recurse-submodules https://github.com/VAST-AI-Research/AniGen.git
|
| 54 |
+
cd AniGen
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
2. Install the dependencies:
|
| 58 |
+
|
| 59 |
+
We recommend using [uv](https://docs.astral.sh/uv/) for fast, reliable installs. The setup script will also work with plain `pip` if `uv` is not available.
|
| 60 |
+
|
| 61 |
+
Create a new virtual environment and install everything:
|
| 62 |
+
```sh
|
| 63 |
+
source ./setup.sh --new-env --all
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
If your network connection to PyPI is unstable or slow, you can use the Tsinghua mirror:
|
| 67 |
+
```sh
|
| 68 |
+
source ./setup.sh --new-env --all --tsinghua
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
If you already have an environment with PyTorch installed, install into it directly:
|
| 72 |
+
```sh
|
| 73 |
+
source ./setup.sh --basic
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
> [!NOTE]
|
| 77 |
+
> The setup script auto-detects your CUDA version and installs matching wheels for PyTorch, spconv, pytorch3d, and nvdiffrast. [DSINE](https://github.com/baegwangbin/DSINE) (used for surface normal estimation) is loaded at runtime via `torch.hub` and does not require separate installation.
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
<!-- Pretrained Models -->
|
| 81 |
+
## 🤖 Pretrained Models
|
| 82 |
+
|
| 83 |
+
We provide the following pretrained models on [Hugging Face](https://huggingface.co/VAST-AI/AniGen/tree/main). Please make sure to download all necessary weights from this page, including the required dinov2, dsine, and vgg checkpoints.
|
| 84 |
+
|
| 85 |
+
> [!TIP]
|
| 86 |
+
> **Recommended:** Use **SS-Flow-Duet** + **SLAT-Flow-Auto** if you do not have specific requirements.
|
| 87 |
+
> - For more detailed skeleton (including character fingers) → **SS-Flow-Duet**
|
| 88 |
+
> - For better geometry generalization → **SS-Flow-Solo**
|
| 89 |
+
> - **SLAT-Flow-Control** supports density levels 0–4, but if the density condition significantly deviates from the proper value for the object, skinning weights may be damaged.
|
| 90 |
+
|
| 91 |
+
| DAE Model | Description | Download |
|
| 92 |
+
| --- | --- | --- |
|
| 93 |
+
| SS-DAE | Encoder&Decoder of SS | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/ss_dae) |
|
| 94 |
+
| SLAT-DAE | Encoder&Decoder of SLAT | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/slat_dae) |
|
| 95 |
+
|
| 96 |
+
| SS Model | Description | Download |
|
| 97 |
+
| --- | --- | --- |
|
| 98 |
+
| SS-Flow-Duet | Detailed Skeleton (Full-FT Geo) | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/ss_flow_duet) |
|
| 99 |
+
| SS-Flow-Epic | Geometry&Skeleton Balanced (LoRA-FT Geo) | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/ss_flow_epic) |
|
| 100 |
+
| SS-Flow-Solo | Accurate Geometry (Freeze Geo) | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/ss_flow_solo) |
|
| 101 |
+
|
| 102 |
+
| SLAT Model | Description | Download |
|
| 103 |
+
| --- | --- | --- |
|
| 104 |
+
| SLAT-Flow-Auto | Automatically Determine Joint Number | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/slat_flow_auto) |
|
| 105 |
+
| SLAT-Flow-Control | Controllable Joint Density | [Download](https://huggingface.co/VAST-AI/AniGen/tree/main/ckpts/anigen/slat_flow_control) |
|
| 106 |
+
|
| 107 |
+
<!-- Usage -->
|
| 108 |
+
## 💡 Usage
|
| 109 |
+
|
| 110 |
+
### Minimal Example
|
| 111 |
+
|
| 112 |
+
Here is an [example](example.py) of how to use the pretrained models for 3D asset generation.
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
After running the code, you will get the following files:
|
| 116 |
+
- `mesh.glb`: a rigged mesh file
|
| 117 |
+
- `skeleton.glb`: a skeleton visualization file
|
| 118 |
+
- `processed_image.png`: the masked image as the condition
|
| 119 |
+
|
| 120 |
+
### AniGen Pipeline (Rigged Mesh + Skeleton)
|
| 121 |
+
|
| 122 |
+
For AniGen checkpoints in this repo (e.g. `ckpts/anigen/ss_flow_solo` + `ckpts/anigen/slat_flow_control`), you can run:
|
| 123 |
+
```sh
|
| 124 |
+
python example.py --image_path assets/cond_images/trex.png
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### Web Demo
|
| 128 |
+
|
| 129 |
+
[app.py](app.py) provides a simple web demo for 3D asset generation. Since this demo is based on [Gradio](https://gradio.app/), additional dependencies are required:
|
| 130 |
+
```sh
|
| 131 |
+
source ./setup.sh --demo
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
If needed, you can also install the demo dependencies via the Tsinghua mirror:
|
| 135 |
+
```sh
|
| 136 |
+
source ./setup.sh --demo --tsinghua
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
After installing the dependencies, you can run the demo with the following command:
|
| 140 |
+
```sh
|
| 141 |
+
python app.py
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Then, you can access the demo at the address shown in the terminal.
|
| 145 |
+
|
| 146 |
+
***The web demo is also available on [Hugging Face Spaces](https://huggingface.co/spaces/VAST-AI/AniGen)!***
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
<!-- Training -->
|
| 150 |
+
## 🏋️ Training
|
| 151 |
+
|
| 152 |
+
### Training Data
|
| 153 |
+
|
| 154 |
+
Sample training data is available at [AniGen_sample_data](https://huggingface.co/datasets/VAST-AI/AniGen_sample_data). To prepare your own data, refer to [TRELLIS](https://github.com/microsoft/TRELLIS) and the sample data format.
|
| 155 |
+
|
| 156 |
+
### Prerequisites
|
| 157 |
+
|
| 158 |
+
> [!NOTE]
|
| 159 |
+
> Training requires the **CUBVH** extension (`extensions/CUBVH/`), which is automatically built by `setup.sh`. It is **not** needed for inference (`app.py`, `example.py`).
|
| 160 |
+
|
| 161 |
+
### Training Commands
|
| 162 |
+
|
| 163 |
+
The pipeline has five stages. Later stages depend on earlier ones, so please train in order:
|
| 164 |
+
|
| 165 |
+
```sh
|
| 166 |
+
# Stage 1: Skin AutoEncoder
|
| 167 |
+
python train.py --config configs/anigen_skin_ae.json --output_dir outputs/anigen_skin_ae
|
| 168 |
+
|
| 169 |
+
# Stage 2: Sparse Structure DAE
|
| 170 |
+
python train.py --config configs/ss_dae.json --output_dir outputs/ss_dae
|
| 171 |
+
|
| 172 |
+
# Stage 3: Structured Latent DAE
|
| 173 |
+
python train.py --config configs/slat_dae.json --output_dir outputs/slat_dae
|
| 174 |
+
|
| 175 |
+
# Stage 4: SS Flow Matching (image-conditioned generation)
|
| 176 |
+
python train.py --config configs/ss_flow_duet.json --output_dir outputs/ss_flow_duet
|
| 177 |
+
|
| 178 |
+
# Stage 5: SLAT Flow Matching (image-conditioned generation)
|
| 179 |
+
python train.py --config configs/slat_flow_auto.json --output_dir outputs/slat_flow_auto
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
### Multi-Node / Multi-GPU
|
| 183 |
+
|
| 184 |
+
Append the following flags for distributed training across multiple machines and GPUs:
|
| 185 |
+
|
| 186 |
+
```sh
|
| 187 |
+
python train.py --config configs/<config>.json --output_dir outputs/<output> \
|
| 188 |
+
--num_nodes XX --node_rank XX --master_addr XX --master_port XX
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
### Model Variants
|
| 192 |
+
|
| 193 |
+
Other SS Flow variants (`ss_flow_epic`, `ss_flow_solo`) and SLAT Flow variants (`slat_flow_control`, `slat_flow_gsn_auto`) are available under `ckpts/anigen/`. Their config files can be found at `ckpts/anigen/<variant>/config.json`.
|
| 194 |
+
|
| 195 |
+
### Resume / Restart
|
| 196 |
+
|
| 197 |
+
Training automatically resumes from the latest checkpoint in `--output_dir`. To start fresh, pass `--ckpt none`.
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
## License
|
| 201 |
+
|
| 202 |
+
This project's source code is released under the [MIT License](LICENSE).
|
| 203 |
+
|
| 204 |
+
> [!IMPORTANT]
|
| 205 |
+
> This repository includes third-party components with additional license restrictions. In particular, `extensions/CUBVH/` contains BVH code derived from NVIDIA's [instant-ngp](https://github.com/NVlabs/instant-ngp), which is licensed for **non-commercial / research use only**. See [THIRD_PARTY_LICENSES.md](THIRD_PARTY_LICENSES.md) for details.
|
| 206 |
+
|
| 207 |
+
## Acknowledgements
|
| 208 |
+
|
| 209 |
+
- [TRELLIS](https://github.com/microsoft/TRELLIS) by Microsoft
|
| 210 |
+
- [cuBVH](https://github.com/ashawkey/cubvh) by Jiaxiang Tang
|
| 211 |
+
- [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) and [instant-ngp](https://github.com/NVlabs/instant-ngp) by Thomas Müller / NVIDIA
|
| 212 |
+
- [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes) by NVIDIA
|
| 213 |
+
|
| 214 |
+
We sincerely appreciate the contributions of these excellent projects and their authors. We believe open source helps accelerate research, lower barriers to innovation, and make progress more accessible to the broader community.
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
<!-- Citation -->
|
| 218 |
+
## 📜 Citation
|
| 219 |
+
|
| 220 |
+
If you find this work helpful, please consider citing our paper:
|
| 221 |
+
|
| 222 |
+
```bibtex
|
| 223 |
+
@article{huang2026anigen,
|
| 224 |
+
title = {AniGen: Unified $S^3$ Fields for Animatable 3D Asset Generation},
|
| 225 |
+
author = {Huang, Yi-Hua and Zhou, Zi-Xin and He, Yuting and Chang, Chirui
|
| 226 |
+
and Pu, Cheng-Feng and Yang, Ziyi and Guo, Yuan-Chen
|
| 227 |
+
and Cao, Yan-Pei and Qi, Xiaojuan},
|
| 228 |
+
journal = {ACM SIGGRAPH},
|
| 229 |
+
year = {2026}
|
| 230 |
+
}
|
| 231 |
+
```
|
THIRD_PARTY_LICENSES.md
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Third-Party Licenses
|
| 2 |
+
|
| 3 |
+
## extensions/CUBVH/ — cuBVH (CUDA Mesh BVH Acceleration)
|
| 4 |
+
|
| 5 |
+
Originally created by [Jiaxiang Tang (ashawkey)](https://github.com/ashawkey/cubvh),
|
| 6 |
+
modified by Yi-Hua Huang (yihua7).
|
| 7 |
+
|
| 8 |
+
### MIT License (cubvh overall)
|
| 9 |
+
- File: `extensions/CUBVH/LICENSE`
|
| 10 |
+
- Copyright (c) 2022 Jiaxiang Tang (ashawkey)
|
| 11 |
+
- Copyright (c) 2025 Yi-Hua Huang (yihua7)
|
| 12 |
+
|
| 13 |
+
### NVIDIA Source Code License — Non-Commercial (BVH from instant-ngp)
|
| 14 |
+
- File: `extensions/CUBVH/LICENSE_NVIDIA`
|
| 15 |
+
- Copyright (c) 2022, NVIDIA Corporation & affiliates
|
| 16 |
+
- **USE RESTRICTED TO NON-COMMERCIAL / RESEARCH PURPOSES ONLY**
|
| 17 |
+
|
| 18 |
+
### BSD 3-Clause License (gpu_memory.h from tiny-cuda-nn)
|
| 19 |
+
- File header: `extensions/CUBVH/include/gpu/gpu_memory.h`
|
| 20 |
+
- Copyright (c) 2020-2022, NVIDIA CORPORATION
|
| 21 |
+
|
| 22 |
+
### Apache License 2.0 (pcg32.h)
|
| 23 |
+
- File header: `extensions/CUBVH/include/gpu/pcg32.h`
|
| 24 |
+
- Author: Wenzel Jakob, modified by tiny-cuda-nn
|
| 25 |
+
|
| 26 |
+
## anigen/representations/mesh/flexicubes/ — FlexiCubes
|
| 27 |
+
|
| 28 |
+
- File: `anigen/representations/mesh/flexicubes/LICENSE.txt`
|
| 29 |
+
- Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES
|
| 30 |
+
- Apache License 2.0
|
anigen/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import models
|
| 2 |
+
from . import modules
|
| 3 |
+
from . import pipelines
|
| 4 |
+
from . import renderers
|
| 5 |
+
from . import representations
|
| 6 |
+
from . import utils
|
anigen/datasets/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
|
| 3 |
+
__attributes = {
|
| 4 |
+
'AniGenSparseStructure': 'anigen_sparse_structure',
|
| 5 |
+
'AniGenSparseFeat2Skeleton': 'anigen_sparse_feat2skeleton',
|
| 6 |
+
'AniGenSparseFeat2Render': 'anigen_sparse_feat2render',
|
| 7 |
+
|
| 8 |
+
'AniGenSparseStructureLatent': 'anigen_sparse_structure_latent',
|
| 9 |
+
'TextConditionedAniGenSparseStructureLatent': 'anigen_sparse_structure_latent',
|
| 10 |
+
'ImageConditionedAniGenSparseStructureLatent': 'anigen_sparse_structure_latent',
|
| 11 |
+
|
| 12 |
+
'AniGenSLat': 'anigen_structured_latent',
|
| 13 |
+
'AniGenTextConditionedSLat': 'anigen_structured_latent',
|
| 14 |
+
'AniGenImageConditionedSLat': 'anigen_structured_latent',
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
__submodules = []
|
| 18 |
+
|
| 19 |
+
__all__ = list(__attributes.keys()) + __submodules
|
| 20 |
+
|
| 21 |
+
def __getattr__(name):
|
| 22 |
+
if name not in globals():
|
| 23 |
+
if name in __attributes:
|
| 24 |
+
module_name = __attributes[name]
|
| 25 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
| 26 |
+
globals()[name] = getattr(module, name)
|
| 27 |
+
elif name in __submodules:
|
| 28 |
+
module = importlib.import_module(f".{name}", __name__)
|
| 29 |
+
globals()[name] = module
|
| 30 |
+
else:
|
| 31 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
| 32 |
+
return globals()[name]
|
anigen/datasets/anigen_sparse_feat2skeleton.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
import utils3d.torch
|
| 8 |
+
from ..modules.sparse.basic import SparseTensor
|
| 9 |
+
from .components import StandardDatasetBase
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AniGenSparseFeat2Skeleton(StandardDatasetBase):
|
| 13 |
+
"""
|
| 14 |
+
SparseFeat2Render dataset.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
roots (str): paths to the dataset
|
| 18 |
+
image_size (int): size of the image
|
| 19 |
+
model (str): model name
|
| 20 |
+
resolution (int): resolution of the data
|
| 21 |
+
min_aesthetic_score (float): minimum aesthetic score
|
| 22 |
+
max_num_voxels (int): maximum number of voxels
|
| 23 |
+
"""
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
roots: str,
|
| 27 |
+
image_size: int,
|
| 28 |
+
model: str = 'dinov2_vitl14_reg',
|
| 29 |
+
resolution: int = 64,
|
| 30 |
+
min_aesthetic_score: float = 5.0,
|
| 31 |
+
max_num_voxels: int = 32768,
|
| 32 |
+
load_cubvh: bool = False,
|
| 33 |
+
skl_dilation_iter: int = 0,
|
| 34 |
+
skl_dilation_random_aug: bool = False,
|
| 35 |
+
skl_dilation_random_aug_prob: float = 0.5,
|
| 36 |
+
filter_bad_skin: bool = False,
|
| 37 |
+
|
| 38 |
+
test_mode: bool = True, # Test the model performance
|
| 39 |
+
is_test: bool = False, # Train or validation
|
| 40 |
+
skin_accum_as_flow: bool = False, # Accumulate skin weights from bottom to top as flow-by probability
|
| 41 |
+
local_rank: int = 0,
|
| 42 |
+
joint_merge_res: int = 64,
|
| 43 |
+
**kwargs,
|
| 44 |
+
):
|
| 45 |
+
self.image_size = image_size
|
| 46 |
+
self.model = model
|
| 47 |
+
self.resolution = resolution
|
| 48 |
+
self.min_aesthetic_score = min_aesthetic_score
|
| 49 |
+
self.max_num_voxels = max_num_voxels
|
| 50 |
+
self.value_range = (0, 1)
|
| 51 |
+
self.load_cubvh = load_cubvh
|
| 52 |
+
self.skl_dilation_iter = skl_dilation_iter
|
| 53 |
+
self.skl_dilation_random_aug = skl_dilation_random_aug
|
| 54 |
+
self.skl_dilation_random_aug_prob = skl_dilation_random_aug_prob
|
| 55 |
+
self.filter_bad_skin = filter_bad_skin
|
| 56 |
+
|
| 57 |
+
self.test_mode = test_mode
|
| 58 |
+
self.is_test = is_test
|
| 59 |
+
self.skin_accum_as_flow = skin_accum_as_flow
|
| 60 |
+
self.local_rank = local_rank
|
| 61 |
+
self.joint_merge_res = joint_merge_res
|
| 62 |
+
|
| 63 |
+
super().__init__(roots, **kwargs)
|
| 64 |
+
self.is_bad_skin_list = self.metadata['is_bad_skin'].values
|
| 65 |
+
|
| 66 |
+
def filter_metadata(self, metadata):
|
| 67 |
+
stats = {}
|
| 68 |
+
metadata = metadata[metadata[f'feature_{self.model}']]
|
| 69 |
+
stats['With features'] = len(metadata)
|
| 70 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
| 71 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
| 72 |
+
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
|
| 73 |
+
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
|
| 74 |
+
|
| 75 |
+
if 'is_bad_skeleton' in metadata.columns:
|
| 76 |
+
metadata = metadata[~metadata['is_bad_skeleton']]
|
| 77 |
+
if self.filter_bad_skin and 'is_bad_skin' in metadata.columns:
|
| 78 |
+
metadata = metadata[~metadata['is_bad_skin']]
|
| 79 |
+
|
| 80 |
+
if self.test_mode:
|
| 81 |
+
metadata = metadata[metadata['is_test']] if self.is_test else metadata[~metadata['is_test']]
|
| 82 |
+
|
| 83 |
+
return metadata, stats
|
| 84 |
+
|
| 85 |
+
def _get_image(self, root, instance):
|
| 86 |
+
with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
|
| 87 |
+
metadata = json.load(f)
|
| 88 |
+
n_views = len(metadata['frames'])
|
| 89 |
+
view = np.random.randint(n_views)
|
| 90 |
+
metadata = metadata['frames'][view]
|
| 91 |
+
fov = metadata['camera_angle_x']
|
| 92 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
|
| 93 |
+
c2w = torch.tensor(metadata['transform_matrix'])
|
| 94 |
+
c2w[:3, 1:3] *= -1
|
| 95 |
+
extrinsics = torch.inverse(c2w)
|
| 96 |
+
|
| 97 |
+
image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
|
| 98 |
+
image = Image.open(image_path)
|
| 99 |
+
alpha = image.getchannel(3)
|
| 100 |
+
image = image.convert('RGB')
|
| 101 |
+
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
| 102 |
+
alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
| 103 |
+
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
|
| 104 |
+
alpha = torch.tensor(np.array(alpha)).float() / 255.0
|
| 105 |
+
|
| 106 |
+
return {
|
| 107 |
+
'image': image,
|
| 108 |
+
'alpha': alpha,
|
| 109 |
+
'extrinsics': extrinsics,
|
| 110 |
+
'intrinsics': intrinsics,
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
def _get_feat(self, root, instance):
|
| 114 |
+
DATA_RESOLUTION = 64
|
| 115 |
+
feats_path = os.path.join(root, 'features', self.model, f'{instance}.npz')
|
| 116 |
+
feats_data = np.load(feats_path, allow_pickle=True)
|
| 117 |
+
coords = torch.tensor(feats_data['indices']).int()
|
| 118 |
+
feats = torch.tensor(feats_data['patchtokens']).float()
|
| 119 |
+
|
| 120 |
+
position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}_skeleton.ply'))[0]
|
| 121 |
+
coords_skl = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
|
| 122 |
+
ss_skl = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
|
| 123 |
+
ss_skl[0, coords_skl[:,0], coords_skl[:,1], coords_skl[:,2]] = 1
|
| 124 |
+
ss_skl_ori = ss_skl.clone()
|
| 125 |
+
if self.skl_dilation_random_aug or self.skl_dilation_iter > 0:
|
| 126 |
+
size = max(0, self.skl_dilation_iter) * 2 + 1
|
| 127 |
+
if self.skl_dilation_iter > 0:
|
| 128 |
+
kernel = torch.ones(1, 1, size, size, size, dtype=torch.float32, device=ss_skl.device)
|
| 129 |
+
ss_skl = torch.nn.functional.conv3d(ss_skl.float()[None], kernel, padding=self.skl_dilation_iter)
|
| 130 |
+
ss_skl = (ss_skl > 0).long().squeeze(0)
|
| 131 |
+
coords_skl = torch.nonzero(ss_skl[0], as_tuple=False).int()
|
| 132 |
+
if self.skl_dilation_random_aug and np.random.rand() < self.skl_dilation_random_aug_prob:
|
| 133 |
+
size_small, size_large = size - 2, size + 2
|
| 134 |
+
kernel_large = torch.ones(1, 1, size_large, size_large, size_large, dtype=torch.float32, device=ss_skl_ori.device)
|
| 135 |
+
ss_skl_large = torch.nn.functional.conv3d(ss_skl_ori.float()[None], kernel_large, padding=size_large//2)
|
| 136 |
+
ss_skl_large = (ss_skl_large > 0).long().squeeze(0)
|
| 137 |
+
if size_small > 1:
|
| 138 |
+
kernel_small = torch.ones(1, 1, size_small, size_small, size_small, dtype=torch.float32, device=ss_skl_ori.device)
|
| 139 |
+
ss_skl_small = torch.nn.functional.conv3d(ss_skl_ori.float()[None], kernel_small, padding=size_small//2)
|
| 140 |
+
ss_skl_small = (ss_skl_small > 0).long().squeeze(0)
|
| 141 |
+
else:
|
| 142 |
+
ss_skl_small = torch.zeros_like(ss_skl)
|
| 143 |
+
|
| 144 |
+
ss_skl_random_mask = torch.rand_like(ss_skl.float()) < 0.5
|
| 145 |
+
ss_skl = ss_skl_small * ss_skl_random_mask.long() + ss_skl_large * (1 - ss_skl_random_mask.long())
|
| 146 |
+
coords_skl = torch.nonzero(ss_skl[0], as_tuple=False).int()
|
| 147 |
+
feats_skl = torch.zeros((coords_skl.shape[0], 0), dtype=torch.float32)
|
| 148 |
+
|
| 149 |
+
if self.resolution != DATA_RESOLUTION:
|
| 150 |
+
factor = DATA_RESOLUTION // self.resolution
|
| 151 |
+
coords = coords // factor
|
| 152 |
+
coords, idx = coords.unique(return_inverse=True, dim=0)
|
| 153 |
+
feats = torch.scatter_reduce(
|
| 154 |
+
torch.zeros(coords.shape[0], feats.shape[1], device=feats.device),
|
| 155 |
+
dim=0,
|
| 156 |
+
index=idx.unsqueeze(-1).expand(-1, feats.shape[1]),
|
| 157 |
+
src=feats,
|
| 158 |
+
reduce='mean'
|
| 159 |
+
)
|
| 160 |
+
coords_skl = coords_skl // factor
|
| 161 |
+
coords_skl, idx = coords_skl.unique(return_inverse=True, dim=0)
|
| 162 |
+
feats_skl = torch.scatter_reduce(
|
| 163 |
+
torch.zeros(coords_skl.shape[0], feats_skl.shape[1], device=feats_skl.device),
|
| 164 |
+
dim=0,
|
| 165 |
+
index=idx.unsqueeze(-1).expand(-1, feats_skl.shape[1]),
|
| 166 |
+
src=feats_skl,
|
| 167 |
+
reduce='mean'
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
return {
|
| 171 |
+
'coords': coords,
|
| 172 |
+
'feats': feats,
|
| 173 |
+
'coords_skl': coords_skl,
|
| 174 |
+
'feats_skl': feats_skl,
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
@torch.no_grad()
|
| 178 |
+
def visualize_sample(self, sample: dict):
|
| 179 |
+
return sample['image']
|
| 180 |
+
|
| 181 |
+
@staticmethod
|
| 182 |
+
def collate_fn(batch):
|
| 183 |
+
pack = {}
|
| 184 |
+
coords = []
|
| 185 |
+
coords_skl = []
|
| 186 |
+
for i, b in enumerate(batch):
|
| 187 |
+
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
|
| 188 |
+
coords_skl.append(torch.cat([torch.full((b['coords_skl'].shape[0], 1), i, dtype=torch.int32), b['coords_skl']], dim=-1))
|
| 189 |
+
coords = torch.cat(coords)
|
| 190 |
+
feats = torch.cat([b['feats'] for b in batch])
|
| 191 |
+
pack['feats'] = SparseTensor(
|
| 192 |
+
coords=coords,
|
| 193 |
+
feats=feats,
|
| 194 |
+
)
|
| 195 |
+
coords_skl = torch.cat(coords_skl)
|
| 196 |
+
feats_skl = torch.cat([b['feats_skl'] for b in batch])
|
| 197 |
+
pack['feats_skl'] = SparseTensor(
|
| 198 |
+
coords=coords_skl,
|
| 199 |
+
feats=feats_skl,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
pack['image'] = torch.stack([b['image'] for b in batch])
|
| 203 |
+
pack['alpha'] = torch.stack([b['alpha'] for b in batch])
|
| 204 |
+
pack['extrinsics'] = torch.stack([b['extrinsics'] for b in batch])
|
| 205 |
+
pack['intrinsics'] = torch.stack([b['intrinsics'] for b in batch])
|
| 206 |
+
|
| 207 |
+
pack['joints'] = [b['joints'] for b in batch]
|
| 208 |
+
pack['parents'] = [b['parents'] for b in batch]
|
| 209 |
+
pack['skin'] = [b['skin'] for b in batch]
|
| 210 |
+
pack['is_bad_skin'] = [b['is_bad_skin'] for b in batch]
|
| 211 |
+
|
| 212 |
+
# collate other data
|
| 213 |
+
keys = [k for k in batch[0].keys() if k not in ['coords', 'feats', 'coords_skl', 'feats_skl', 'image', 'alpha', 'extrinsics', 'intrinsics', 'joints', 'parents', 'skin']]
|
| 214 |
+
for k in keys:
|
| 215 |
+
if isinstance(batch[0][k], torch.Tensor):
|
| 216 |
+
pack[k] = torch.stack([b[k] for b in batch])
|
| 217 |
+
elif isinstance(batch[0][k], list):
|
| 218 |
+
pack[k] = sum([b[k] for b in batch], [])
|
| 219 |
+
else:
|
| 220 |
+
pack[k] = [b[k] for b in batch]
|
| 221 |
+
|
| 222 |
+
return pack
|
| 223 |
+
|
| 224 |
+
def _get_geo(self, root, instance):
|
| 225 |
+
skeleton_path = os.path.join(root, 'skeleton', instance, 'skeleton_voxelized.npz')
|
| 226 |
+
skl_data = np.load(skeleton_path, allow_pickle=True)
|
| 227 |
+
verts, face = np.array(skl_data['vertices'], dtype=np.float32), skl_data['faces']
|
| 228 |
+
mesh = {
|
| 229 |
+
"vertices" : torch.from_numpy(verts),
|
| 230 |
+
"faces" : torch.from_numpy(face),
|
| 231 |
+
}
|
| 232 |
+
geo = {"mesh": mesh}
|
| 233 |
+
if self.load_cubvh:
|
| 234 |
+
from cubvh import cuBVH
|
| 235 |
+
torch.cuda.set_device(self.local_rank)
|
| 236 |
+
cubvh_path = os.path.join(root, 'skeleton', instance, 'cubvh.pth')
|
| 237 |
+
if os.path.exists(cubvh_path):
|
| 238 |
+
bvh = torch.load(cubvh_path, weights_only=False)
|
| 239 |
+
if isinstance(bvh, cuBVH):
|
| 240 |
+
bvh = bvh.to('cpu')
|
| 241 |
+
else:
|
| 242 |
+
device = torch.device(f"cuda:{self.local_rank}")
|
| 243 |
+
bvh = cuBVH(mesh["vertices"], mesh["faces"], device=device)
|
| 244 |
+
bvh = bvh.to('cpu')
|
| 245 |
+
torch.save(bvh, cubvh_path)
|
| 246 |
+
geo["cubvh"] = bvh
|
| 247 |
+
return geo
|
| 248 |
+
|
| 249 |
+
def _get_skeleton(self, root, instance):
|
| 250 |
+
skeleton_path = os.path.join(root, 'skeleton', instance, 'skeleton_voxelized.npz')
|
| 251 |
+
skl_data = np.load(skeleton_path, allow_pickle=True)
|
| 252 |
+
joints, parents, skin = skl_data['joints'], skl_data['parents'], skl_data['skin']
|
| 253 |
+
parents[parents==None] = -1
|
| 254 |
+
parents = np.array(parents, dtype=np.int32)
|
| 255 |
+
|
| 256 |
+
skin[np.where(skl_data['skin'].max(axis=1)==0)[0], 0] = 1.0
|
| 257 |
+
skin = skin / skin.sum(-1, keepdims=True)
|
| 258 |
+
|
| 259 |
+
if self.skin_accum_as_flow:
|
| 260 |
+
root_idx = np.where(parents == -1)[0][0]
|
| 261 |
+
def sum_children(joint_idx, skin_weights):
|
| 262 |
+
children = np.where(parents == joint_idx)[0]
|
| 263 |
+
for child in children:
|
| 264 |
+
skin_weights[:, joint_idx] += sum_children(child, skin_weights)
|
| 265 |
+
return skin_weights[:, joint_idx]
|
| 266 |
+
sum_children(root_idx, skin)
|
| 267 |
+
skin = np.clip(skin, 0, 1)
|
| 268 |
+
|
| 269 |
+
is_bad_skin = self.metadata['is_bad_skin'][instance]
|
| 270 |
+
|
| 271 |
+
return {
|
| 272 |
+
'joints': torch.from_numpy(joints).float(),
|
| 273 |
+
'parents': torch.from_numpy(parents).int(),
|
| 274 |
+
'skin': torch.from_numpy(skin).float(),
|
| 275 |
+
'is_bad_skin': is_bad_skin
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
def get_instance(self, root, instance):
|
| 279 |
+
image = self._get_image(root, instance)
|
| 280 |
+
feat = self._get_feat(root, instance)
|
| 281 |
+
geo = self._get_geo(root, instance)
|
| 282 |
+
skl = self._get_skeleton(root, instance)
|
| 283 |
+
|
| 284 |
+
return {
|
| 285 |
+
**image,
|
| 286 |
+
**feat,
|
| 287 |
+
**geo,
|
| 288 |
+
**skl,
|
| 289 |
+
'instance': instance,
|
| 290 |
+
}
|
anigen/datasets/anigen_sparse_structure.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from typing import Union
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
import utils3d
|
| 9 |
+
from .components import StandardDatasetBase
|
| 10 |
+
from ..representations.octree import DfsOctree as Octree
|
| 11 |
+
from ..renderers import OctreeRenderer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AniGenSparseStructure(StandardDatasetBase):
|
| 15 |
+
"""
|
| 16 |
+
Sparse structure dataset
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
roots (str): path to the dataset
|
| 20 |
+
resolution (int): resolution of the voxel grid
|
| 21 |
+
min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self,
|
| 25 |
+
roots,
|
| 26 |
+
resolution: int = 64,
|
| 27 |
+
min_aesthetic_score: float = 5.0,
|
| 28 |
+
skl_dilation_iter: int = 0,
|
| 29 |
+
**kwargs,
|
| 30 |
+
):
|
| 31 |
+
self.resolution = resolution
|
| 32 |
+
self.min_aesthetic_score = min_aesthetic_score
|
| 33 |
+
self.skl_dilation_iter = skl_dilation_iter
|
| 34 |
+
self.value_range = (0, 1)
|
| 35 |
+
|
| 36 |
+
super().__init__(roots)
|
| 37 |
+
|
| 38 |
+
def filter_metadata(self, metadata):
|
| 39 |
+
stats = {}
|
| 40 |
+
metadata = metadata[metadata[f'voxelized']]
|
| 41 |
+
stats['Voxelized'] = len(metadata)
|
| 42 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
| 43 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
| 44 |
+
return metadata, stats
|
| 45 |
+
|
| 46 |
+
def get_ply_instance(self, root, instance, dilation_iter=None):
|
| 47 |
+
if dilation_iter is None:
|
| 48 |
+
dilation_iter = self.skl_dilation_iter
|
| 49 |
+
|
| 50 |
+
position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
|
| 51 |
+
coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
|
| 52 |
+
ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
|
| 53 |
+
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
|
| 54 |
+
if dilation_iter > 0:
|
| 55 |
+
# 3D Dilation
|
| 56 |
+
size = dilation_iter * 2 + 1
|
| 57 |
+
kernel = torch.ones(1, 1, size, size, size, dtype=torch.float32, device=ss.device)
|
| 58 |
+
ss = torch.nn.functional.conv3d(ss.float()[None], kernel, padding=dilation_iter)
|
| 59 |
+
ss = (ss > 0).long().squeeze(0)
|
| 60 |
+
return ss
|
| 61 |
+
|
| 62 |
+
def get_instance(self, root, instance):
|
| 63 |
+
ss = self.get_ply_instance(root, instance, dilation_iter=0)
|
| 64 |
+
ss_skl = self.get_ply_instance(root, f'{instance}_skeleton', dilation_iter=self.skl_dilation_iter)
|
| 65 |
+
return {'ss': ss, 'ss_skl': ss_skl, 'instance': instance}
|
| 66 |
+
|
| 67 |
+
@torch.no_grad()
|
| 68 |
+
def visualize_sample(self, ss: Union[torch.Tensor, dict]):
|
| 69 |
+
ss = ss if isinstance(ss, torch.Tensor) else ss['ss']
|
| 70 |
+
|
| 71 |
+
renderer = OctreeRenderer()
|
| 72 |
+
renderer.rendering_options.resolution = 512
|
| 73 |
+
renderer.rendering_options.near = 0.8
|
| 74 |
+
renderer.rendering_options.far = 1.6
|
| 75 |
+
renderer.rendering_options.bg_color = (0, 0, 0)
|
| 76 |
+
renderer.rendering_options.ssaa = 4
|
| 77 |
+
renderer.pipe.primitive = 'voxel'
|
| 78 |
+
|
| 79 |
+
# Build camera
|
| 80 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
| 81 |
+
yaws_offset = 0. # np.random.uniform(-np.pi / 4, np.pi / 4)
|
| 82 |
+
yaws = [y + yaws_offset for y in yaws]
|
| 83 |
+
pitch = np.linspace(-np.pi / 4, np.pi / 4, num=4) # [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
| 84 |
+
|
| 85 |
+
exts = []
|
| 86 |
+
ints = []
|
| 87 |
+
for yaw, pitch in zip(yaws, pitch):
|
| 88 |
+
orig = torch.tensor([
|
| 89 |
+
np.sin(yaw) * np.cos(pitch),
|
| 90 |
+
np.cos(yaw) * np.cos(pitch),
|
| 91 |
+
np.sin(pitch),
|
| 92 |
+
]).float().cuda() * 2
|
| 93 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
| 94 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
| 95 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
| 96 |
+
exts.append(extrinsics)
|
| 97 |
+
ints.append(intrinsics)
|
| 98 |
+
|
| 99 |
+
images = []
|
| 100 |
+
|
| 101 |
+
# Build each representation
|
| 102 |
+
ss = ss.cuda()
|
| 103 |
+
for i in range(ss.shape[0]):
|
| 104 |
+
representation = Octree(
|
| 105 |
+
depth=10,
|
| 106 |
+
aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
|
| 107 |
+
device='cuda',
|
| 108 |
+
primitive='voxel',
|
| 109 |
+
sh_degree=0,
|
| 110 |
+
primitive_config={'solid': True},
|
| 111 |
+
)
|
| 112 |
+
coords = torch.nonzero(ss[i, 0], as_tuple=False)
|
| 113 |
+
representation.position = coords.float() / self.resolution
|
| 114 |
+
representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
|
| 115 |
+
|
| 116 |
+
image = torch.zeros(3, 1024, 1024).cuda()
|
| 117 |
+
tile = [2, 2]
|
| 118 |
+
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
| 119 |
+
res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
|
| 120 |
+
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
| 121 |
+
images.append(image)
|
| 122 |
+
|
| 123 |
+
return torch.stack(images)
|
| 124 |
+
|
anigen/datasets/anigen_sparse_structure_latent.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from typing import *
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import utils3d
|
| 7 |
+
from ..representations.octree import DfsOctree as Octree
|
| 8 |
+
from ..renderers import OctreeRenderer
|
| 9 |
+
from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
|
| 10 |
+
from .. import models
|
| 11 |
+
from ..utils.dist_utils import read_file_dist
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AniGenSparseStructureLatentVisMixin:
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
*args,
|
| 19 |
+
pretrained_ss_dec: str = None,
|
| 20 |
+
ss_dec_path: Optional[str] = '',
|
| 21 |
+
ss_dec_ckpt: Optional[str] = 'final',
|
| 22 |
+
**kwargs
|
| 23 |
+
):
|
| 24 |
+
super().__init__(*args, **kwargs)
|
| 25 |
+
self.ss_dec = None
|
| 26 |
+
self.pretrained_ss_dec = pretrained_ss_dec
|
| 27 |
+
self.ss_dec_path = ss_dec_path
|
| 28 |
+
self.ss_dec_ckpt = ss_dec_ckpt
|
| 29 |
+
|
| 30 |
+
def _loading_ss_dec(self):
|
| 31 |
+
if self.ss_dec is not None:
|
| 32 |
+
return
|
| 33 |
+
if self.ss_dec_path is not None:
|
| 34 |
+
cfg = json.load(open(os.path.join(self.ss_dec_path, 'config.json'), 'r'))
|
| 35 |
+
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
|
| 36 |
+
ckpt_path = os.path.join(self.ss_dec_path, 'ckpts', f'decoder_{self.ss_dec_ckpt}.pt')
|
| 37 |
+
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
|
| 38 |
+
# decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True)) # Got stuck...
|
| 39 |
+
else:
|
| 40 |
+
decoder = models.from_pretrained(self.pretrained_ss_dec)
|
| 41 |
+
self.ss_dec = decoder.cuda().eval()
|
| 42 |
+
|
| 43 |
+
def _delete_ss_dec(self):
|
| 44 |
+
del self.ss_dec
|
| 45 |
+
self.ss_dec = None
|
| 46 |
+
|
| 47 |
+
@torch.no_grad()
|
| 48 |
+
def decode_latent(self, z, z_skl, batch_size=4):
|
| 49 |
+
self._loading_ss_dec()
|
| 50 |
+
ss = []
|
| 51 |
+
ss_skl = []
|
| 52 |
+
if self.normalization is not None:
|
| 53 |
+
z = z * self.std.to(z.device) + self.mean.to(z.device)
|
| 54 |
+
if self.normalization_skl is not None:
|
| 55 |
+
z_skl = z_skl * self.std_skl.to(z_skl.device) + self.mean_skl.to(z_skl.device)
|
| 56 |
+
for i in range(0, z.shape[0], batch_size):
|
| 57 |
+
z_, z_skl_ = z[i:i+batch_size], z_skl[i:i+batch_size]
|
| 58 |
+
ss_, ss_skl_ = self.ss_dec(z_, z_skl_)
|
| 59 |
+
ss.append(ss_)
|
| 60 |
+
ss_skl.append(ss_skl_)
|
| 61 |
+
ss = torch.cat(ss, dim=0)
|
| 62 |
+
ss_skl = torch.cat(ss_skl, dim=0)
|
| 63 |
+
self._delete_ss_dec()
|
| 64 |
+
return ss, ss_skl
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
|
| 67 |
+
def visualize_sample(self, x_0: Union[torch.Tensor, dict], x_0_skl: Optional[Union[torch.Tensor, dict]]=None, **kwargs):
|
| 68 |
+
|
| 69 |
+
x_0_skl = x_0_skl if isinstance(x_0, torch.Tensor) else x_0['x_0_skl']
|
| 70 |
+
x_0 = x_0 if isinstance(x_0, torch.Tensor) else x_0['x_0']
|
| 71 |
+
x_0, x_0_skl = self.decode_latent(x_0.cuda(), x_0_skl.cuda())
|
| 72 |
+
|
| 73 |
+
renderer = OctreeRenderer()
|
| 74 |
+
renderer.rendering_options.resolution = 512
|
| 75 |
+
renderer.rendering_options.near = 0.8
|
| 76 |
+
renderer.rendering_options.far = 1.6
|
| 77 |
+
renderer.rendering_options.bg_color = (0, 0, 0)
|
| 78 |
+
renderer.rendering_options.ssaa = 4
|
| 79 |
+
renderer.pipe.primitive = 'voxel'
|
| 80 |
+
|
| 81 |
+
# Build camera
|
| 82 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
| 83 |
+
yaws_offset = 0 # np.random.uniform(-np.pi / 4, np.pi / 4)
|
| 84 |
+
yaws = [y + yaws_offset for y in yaws]
|
| 85 |
+
pitch = np.linspace(-np.pi / 4, np.pi / 4, 4) # [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
| 86 |
+
|
| 87 |
+
exts = []
|
| 88 |
+
ints = []
|
| 89 |
+
for yaw, pitch in zip(yaws, pitch):
|
| 90 |
+
orig = torch.tensor([
|
| 91 |
+
np.sin(yaw) * np.cos(pitch),
|
| 92 |
+
np.cos(yaw) * np.cos(pitch),
|
| 93 |
+
np.sin(pitch),
|
| 94 |
+
]).float().cuda() * 2
|
| 95 |
+
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
| 96 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
| 97 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
| 98 |
+
exts.append(extrinsics)
|
| 99 |
+
ints.append(intrinsics)
|
| 100 |
+
|
| 101 |
+
images = []
|
| 102 |
+
x_0 = x_0.cuda()
|
| 103 |
+
for i in range(x_0.shape[0]):
|
| 104 |
+
representation = Octree(
|
| 105 |
+
depth=10,
|
| 106 |
+
aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
|
| 107 |
+
device='cuda',
|
| 108 |
+
primitive='voxel',
|
| 109 |
+
sh_degree=0,
|
| 110 |
+
primitive_config={'solid': True},
|
| 111 |
+
)
|
| 112 |
+
coords = torch.nonzero(x_0[i, 0] > 0, as_tuple=False)
|
| 113 |
+
resolution = x_0.shape[-1]
|
| 114 |
+
representation.position = coords.float() / resolution
|
| 115 |
+
representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(resolution)), dtype=torch.uint8, device='cuda')
|
| 116 |
+
image = torch.zeros(3, 1024, 1024).cuda()
|
| 117 |
+
tile = [2, 2]
|
| 118 |
+
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
| 119 |
+
res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
|
| 120 |
+
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
| 121 |
+
images.append(image)
|
| 122 |
+
|
| 123 |
+
x_0_skl = x_0_skl.cuda()
|
| 124 |
+
for i in range(x_0_skl.shape[0]):
|
| 125 |
+
representation = Octree(
|
| 126 |
+
depth=10,
|
| 127 |
+
aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
|
| 128 |
+
device='cuda',
|
| 129 |
+
primitive='voxel',
|
| 130 |
+
sh_degree=0,
|
| 131 |
+
primitive_config={'solid': True},
|
| 132 |
+
)
|
| 133 |
+
coords = torch.nonzero(x_0_skl[i, 0] > 0, as_tuple=False)
|
| 134 |
+
resolution = x_0_skl.shape[-1]
|
| 135 |
+
representation.position = coords.float() / resolution
|
| 136 |
+
representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(resolution)), dtype=torch.uint8, device='cuda')
|
| 137 |
+
image = torch.zeros(3, 1024, 1024).cuda()
|
| 138 |
+
tile = [2, 2]
|
| 139 |
+
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
| 140 |
+
res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
|
| 141 |
+
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
| 142 |
+
images[i] = torch.cat([images[i], image], dim=2)
|
| 143 |
+
|
| 144 |
+
return torch.stack(images)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class AniGenSparseStructureLatent(AniGenSparseStructureLatentVisMixin, StandardDatasetBase):
|
| 148 |
+
"""
|
| 149 |
+
Sparse structure latent dataset
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
roots (str): path to the dataset
|
| 153 |
+
latent_model (str): name of the latent model
|
| 154 |
+
min_aesthetic_score (float): minimum aesthetic score
|
| 155 |
+
normalization (dict): normalization stats
|
| 156 |
+
pretrained_ss_dec (str): name of the pretrained sparse structure decoder
|
| 157 |
+
ss_dec_path (str): path to the sparse structure decoder, if given, will override the pretrained_ss_dec
|
| 158 |
+
ss_dec_ckpt (str): name of the sparse structure decoder checkpoint
|
| 159 |
+
"""
|
| 160 |
+
def __init__(self,
|
| 161 |
+
roots: str,
|
| 162 |
+
*,
|
| 163 |
+
latent_model: str,
|
| 164 |
+
min_aesthetic_score: float = 5.0,
|
| 165 |
+
normalization: Optional[dict] = None,
|
| 166 |
+
normalization_skl: Optional[dict] = None,
|
| 167 |
+
pretrained_ss_dec: str = None,
|
| 168 |
+
ss_dec_path: Optional[str] = '',
|
| 169 |
+
ss_dec_ckpt: Optional[str] = 'final',
|
| 170 |
+
**kwargs,
|
| 171 |
+
):
|
| 172 |
+
self.latent_model = latent_model
|
| 173 |
+
self.min_aesthetic_score = min_aesthetic_score
|
| 174 |
+
self.normalization = normalization
|
| 175 |
+
self.normalization_skl = normalization_skl
|
| 176 |
+
self.value_range = (0, 1)
|
| 177 |
+
|
| 178 |
+
super().__init__(
|
| 179 |
+
roots,
|
| 180 |
+
pretrained_ss_dec=pretrained_ss_dec,
|
| 181 |
+
ss_dec_path=ss_dec_path,
|
| 182 |
+
ss_dec_ckpt=ss_dec_ckpt,
|
| 183 |
+
**kwargs,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
if self.normalization is not None:
|
| 187 |
+
data = np.load(self.normalization)
|
| 188 |
+
self.mean = torch.tensor(data['feats_mean'])
|
| 189 |
+
self.std = torch.tensor(data['feats_std'])
|
| 190 |
+
if self.normalization_skl is not None:
|
| 191 |
+
data = np.load(self.normalization_skl)
|
| 192 |
+
self.mean_skl = torch.tensor(data['feats_skl_mean'])
|
| 193 |
+
self.std_skl = torch.tensor(data['feats_skl_std']).clip(min=1e-3)
|
| 194 |
+
|
| 195 |
+
def filter_metadata(self, metadata):
|
| 196 |
+
stats = {}
|
| 197 |
+
metadata = metadata[metadata[f'ss_latent_{self.latent_model}']]
|
| 198 |
+
stats['With sparse structure latents'] = len(metadata)
|
| 199 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
| 200 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
| 201 |
+
|
| 202 |
+
if 'is_bad_skeleton' in metadata.columns:
|
| 203 |
+
metadata = metadata[~metadata['is_bad_skeleton']]
|
| 204 |
+
if 'is_bad_skin' in metadata.columns:
|
| 205 |
+
metadata = metadata[~metadata['is_bad_skin']]
|
| 206 |
+
|
| 207 |
+
return metadata, stats
|
| 208 |
+
|
| 209 |
+
def get_instance(self, root, instance):
|
| 210 |
+
latent = np.load(os.path.join(root, 'ss_latents', self.latent_model, f'{instance}.npz'))
|
| 211 |
+
z = torch.tensor(latent['mean']).float()
|
| 212 |
+
z_skl = torch.tensor(latent['mean_skl']).float()
|
| 213 |
+
if self.normalization is not None:
|
| 214 |
+
z = (z - self.mean) / self.std
|
| 215 |
+
if self.normalization_skl is not None:
|
| 216 |
+
z_skl = (z_skl - self.mean_skl) / self.std_skl
|
| 217 |
+
|
| 218 |
+
pack = {
|
| 219 |
+
'instance': instance,
|
| 220 |
+
'x_0': z,
|
| 221 |
+
'x_0_skl': z_skl,
|
| 222 |
+
}
|
| 223 |
+
return pack
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class TextConditionedAniGenSparseStructureLatent(TextConditionedMixin, AniGenSparseStructureLatent):
|
| 227 |
+
"""
|
| 228 |
+
Text-conditioned sparse structure dataset
|
| 229 |
+
"""
|
| 230 |
+
pass
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class ImageConditionedAniGenSparseStructureLatent(ImageConditionedMixin, AniGenSparseStructureLatent):
|
| 234 |
+
"""
|
| 235 |
+
Image-conditioned sparse structure dataset
|
| 236 |
+
"""
|
| 237 |
+
pass
|
| 238 |
+
|
anigen/datasets/anigen_structured_latent.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import *
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import utils3d.torch
|
| 7 |
+
from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
|
| 8 |
+
from ..modules.sparse.basic import SparseTensor
|
| 9 |
+
from .. import models
|
| 10 |
+
from ..utils.render_utils import get_renderer
|
| 11 |
+
from ..utils.dist_utils import read_file_dist
|
| 12 |
+
from ..utils.data_utils import load_balanced_group_indices
|
| 13 |
+
import copy
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class AniGenSLatVisMixin:
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
*args,
|
| 21 |
+
pretrained_slat_dec: str = None,
|
| 22 |
+
slat_dec_path: Optional[str] = None,
|
| 23 |
+
slat_dec_ckpt: Optional[str] = None,
|
| 24 |
+
load_cubvh: bool = False,
|
| 25 |
+
**kwargs
|
| 26 |
+
):
|
| 27 |
+
super().__init__(*args, **kwargs)
|
| 28 |
+
self.slat_dec = None
|
| 29 |
+
self.pretrained_slat_dec = pretrained_slat_dec
|
| 30 |
+
self.slat_dec_path = slat_dec_path
|
| 31 |
+
self.slat_dec_ckpt = slat_dec_ckpt
|
| 32 |
+
self.load_cubvh = load_cubvh
|
| 33 |
+
|
| 34 |
+
def _loading_slat_dec(self):
|
| 35 |
+
if self.slat_dec is not None:
|
| 36 |
+
return
|
| 37 |
+
if self.slat_dec_path is not None:
|
| 38 |
+
cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
|
| 39 |
+
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
|
| 40 |
+
ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
|
| 41 |
+
# decoder.load_state_dict(torch.load(read_file_dist(ckpt_path), map_location='cpu', weights_only=True))
|
| 42 |
+
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
|
| 43 |
+
else:
|
| 44 |
+
decoder = models.from_pretrained(self.pretrained_slat_dec)
|
| 45 |
+
self.slat_dec = decoder.cuda().eval()
|
| 46 |
+
|
| 47 |
+
def _delete_slat_dec(self):
|
| 48 |
+
del self.slat_dec
|
| 49 |
+
self.slat_dec = None
|
| 50 |
+
|
| 51 |
+
@torch.no_grad()
|
| 52 |
+
def decode_latent(self, z, z_skl, gt_joints=None, gt_parents=None, batch_size=4, gt_reps=None, gt_reps_skl=None):
|
| 53 |
+
self._loading_slat_dec()
|
| 54 |
+
reps = []
|
| 55 |
+
reps_skl = []
|
| 56 |
+
if gt_reps is not None:
|
| 57 |
+
skins_gt_ssskin = []
|
| 58 |
+
if gt_reps_skl is not None:
|
| 59 |
+
skins_gt_sklskin = []
|
| 60 |
+
if self.normalization is not None:
|
| 61 |
+
z = z * self.std.to(z.device) + self.mean.to(z.device)
|
| 62 |
+
z_skl = z_skl * self.std_skl.to(z.device) + self.mean_skl.to(z.device)
|
| 63 |
+
for i in range(0, z.shape[0], batch_size):
|
| 64 |
+
gt_j, gt_p = None if gt_joints is None else gt_joints[i:i+batch_size], None if gt_parents is None else gt_parents[i:i+batch_size]
|
| 65 |
+
z_, z_skl_ = z[i:i+batch_size], z_skl[i:i+batch_size]
|
| 66 |
+
rep, rep_skl = self.slat_dec(z_, z_skl_, gt_joints=gt_j, gt_parents=gt_p)
|
| 67 |
+
reps.append(rep)
|
| 68 |
+
reps_skl.append(rep_skl)
|
| 69 |
+
if gt_reps is not None:
|
| 70 |
+
skins_gt_ssskin.append(self.slat_dec.skinweight_forward(gt_reps[i:i+batch_size], rep_skl, gt_joints=gt_j, gt_parents=gt_p, return_skin_pred_only=True))
|
| 71 |
+
if gt_reps_skl is not None:
|
| 72 |
+
skins_gt_sklskin.append(self.slat_dec.skinweight_forward(rep, gt_reps_skl[i:i+batch_size], gt_joints=gt_j, gt_parents=gt_p, return_skin_pred_only=True))
|
| 73 |
+
reps = sum(reps, [])
|
| 74 |
+
reps_skl = sum(reps_skl, [])
|
| 75 |
+
self._delete_slat_dec()
|
| 76 |
+
to_return = (reps, reps_skl)
|
| 77 |
+
if gt_reps is not None:
|
| 78 |
+
skins_gt_ssskin = sum(skins_gt_ssskin, [])
|
| 79 |
+
to_return += (skins_gt_ssskin,)
|
| 80 |
+
if gt_reps_skl is not None:
|
| 81 |
+
skins_gt_sklskin = sum(skins_gt_sklskin, [])
|
| 82 |
+
to_return += (skins_gt_sklskin,)
|
| 83 |
+
return to_return
|
| 84 |
+
|
| 85 |
+
class AniGenSLat(AniGenSLatVisMixin, StandardDatasetBase):
|
| 86 |
+
"""
|
| 87 |
+
structured latent dataset
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
roots (str): path to the dataset
|
| 91 |
+
latent_model (str): name of the latent model
|
| 92 |
+
min_aesthetic_score (float): minimum aesthetic score
|
| 93 |
+
max_num_voxels (int): maximum number of voxels
|
| 94 |
+
normalization (dict): normalization stats
|
| 95 |
+
pretrained_slat_dec (str): name of the pretrained slat decoder
|
| 96 |
+
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
|
| 97 |
+
slat_dec_ckpt (str): name of the slat decoder checkpoint
|
| 98 |
+
"""
|
| 99 |
+
def __init__(self,
|
| 100 |
+
roots: str,
|
| 101 |
+
*,
|
| 102 |
+
latent_model: str,
|
| 103 |
+
use_joint_num_cond: bool = False,
|
| 104 |
+
min_aesthetic_score: float = 5.0,
|
| 105 |
+
max_num_voxels: int = 32768,
|
| 106 |
+
normalization: Optional[dict] = None,
|
| 107 |
+
pretrained_slat_dec: str = None,
|
| 108 |
+
slat_dec_path: Optional[str] = None,
|
| 109 |
+
slat_dec_ckpt: Optional[str] = None,
|
| 110 |
+
local_rank: int = 0,
|
| 111 |
+
**kwargs,
|
| 112 |
+
):
|
| 113 |
+
self.normalization = normalization
|
| 114 |
+
self.latent_model = latent_model
|
| 115 |
+
self.use_joint_num_cond = use_joint_num_cond
|
| 116 |
+
self.min_aesthetic_score = min_aesthetic_score
|
| 117 |
+
self.max_num_voxels = max_num_voxels
|
| 118 |
+
self.value_range = (0, 1)
|
| 119 |
+
self.local_rank = local_rank
|
| 120 |
+
|
| 121 |
+
super().__init__(
|
| 122 |
+
roots,
|
| 123 |
+
pretrained_slat_dec=pretrained_slat_dec,
|
| 124 |
+
slat_dec_path=slat_dec_path,
|
| 125 |
+
slat_dec_ckpt=slat_dec_ckpt,
|
| 126 |
+
**kwargs,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]
|
| 130 |
+
|
| 131 |
+
if self.normalization is not None:
|
| 132 |
+
self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
|
| 133 |
+
self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
|
| 134 |
+
self.mean_skl = torch.tensor(self.normalization['mean_skl']).reshape(1, -1)
|
| 135 |
+
self.std_skl = torch.tensor(self.normalization['std_skl']).reshape(1, -1)
|
| 136 |
+
|
| 137 |
+
def filter_metadata(self, metadata):
|
| 138 |
+
stats = {}
|
| 139 |
+
metadata = metadata[metadata[f'latent_{self.latent_model}']]
|
| 140 |
+
stats['With latent'] = len(metadata)
|
| 141 |
+
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
| 142 |
+
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
| 143 |
+
# metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
|
| 144 |
+
# stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
|
| 145 |
+
|
| 146 |
+
if 'is_bad_skeleton' in metadata.columns:
|
| 147 |
+
metadata = metadata[~metadata['is_bad_skeleton']]
|
| 148 |
+
if 'is_bad_skin' in metadata.columns:
|
| 149 |
+
metadata = metadata[~metadata['is_bad_skin']]
|
| 150 |
+
|
| 151 |
+
return metadata, stats
|
| 152 |
+
|
| 153 |
+
@torch.no_grad()
|
| 154 |
+
def visualize_sample(self, data: dict):
|
| 155 |
+
return {}
|
| 156 |
+
x_0 = data['x_0']
|
| 157 |
+
x_0_skl = data['x_0_skl']
|
| 158 |
+
reps, reps_skl = self.decode_latent(x_0.cuda(), x_0_skl.cuda(), data['joints'])
|
| 159 |
+
|
| 160 |
+
# Build camera
|
| 161 |
+
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
| 162 |
+
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
| 163 |
+
yaws = [y + yaws_offset for y in yaws]
|
| 164 |
+
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
| 165 |
+
|
| 166 |
+
exts = []
|
| 167 |
+
ints = []
|
| 168 |
+
for yaw, pitch in zip(yaws, pitch):
|
| 169 |
+
orig = torch.tensor([
|
| 170 |
+
np.sin(yaw) * np.cos(pitch),
|
| 171 |
+
np.cos(yaw) * np.cos(pitch),
|
| 172 |
+
np.sin(pitch),
|
| 173 |
+
]).float().cuda() * 2
|
| 174 |
+
fov = torch.deg2rad(torch.tensor(40)).cuda()
|
| 175 |
+
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
| 176 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
| 177 |
+
exts.append(extrinsics)
|
| 178 |
+
ints.append(intrinsics)
|
| 179 |
+
|
| 180 |
+
renderer = get_renderer(reps[0])
|
| 181 |
+
images = []
|
| 182 |
+
for representation in reps:
|
| 183 |
+
image = torch.zeros(3, 1024, 1024).cuda()
|
| 184 |
+
tile = [2, 2]
|
| 185 |
+
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
| 186 |
+
res = renderer.render(representation, ext, intr)
|
| 187 |
+
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
| 188 |
+
images.append(image)
|
| 189 |
+
images = torch.stack(images)
|
| 190 |
+
|
| 191 |
+
return images
|
| 192 |
+
|
| 193 |
+
def _get_skeleton(self, root, instance):
|
| 194 |
+
skeleton_path = os.path.join(root, 'skeleton', instance, 'skeleton_voxelized.npz')
|
| 195 |
+
skl_data = np.load(skeleton_path, allow_pickle=True)
|
| 196 |
+
joints, parents, skin = skl_data['joints'], skl_data['parents'], skl_data['skin']
|
| 197 |
+
parents[parents==None] = -1
|
| 198 |
+
parents = np.array(parents, dtype=np.int32)
|
| 199 |
+
ret = {
|
| 200 |
+
'joints': torch.from_numpy(joints).float(),
|
| 201 |
+
'parents': torch.from_numpy(parents).int(),
|
| 202 |
+
'skin': torch.from_numpy(skin).float(),
|
| 203 |
+
}
|
| 204 |
+
if self.use_joint_num_cond:
|
| 205 |
+
ret['joints_num'] = int(joints.shape[0])
|
| 206 |
+
return ret
|
| 207 |
+
|
| 208 |
+
def _get_geo(self, root, instance):
|
| 209 |
+
skeleton_path = os.path.join(root, 'skeleton', instance, 'skeleton_voxelized.npz')
|
| 210 |
+
skl_data = np.load(skeleton_path, allow_pickle=True)
|
| 211 |
+
verts, face = np.array(skl_data['vertices'], dtype=np.float32), skl_data['faces']
|
| 212 |
+
mesh = {
|
| 213 |
+
"vertices" : torch.from_numpy(verts),
|
| 214 |
+
"faces" : torch.from_numpy(face),
|
| 215 |
+
}
|
| 216 |
+
geo = {"mesh": mesh}
|
| 217 |
+
if self.load_cubvh:
|
| 218 |
+
from cubvh import cuBVH
|
| 219 |
+
torch.cuda.set_device(self.local_rank)
|
| 220 |
+
cubvh_path = os.path.join(root, 'skeleton', instance, 'cubvh.pth')
|
| 221 |
+
if os.path.exists(cubvh_path):
|
| 222 |
+
cubvh = torch.load(cubvh_path, weights_only=False)
|
| 223 |
+
else:
|
| 224 |
+
cubvh = cuBVH(mesh["vertices"], mesh["faces"])
|
| 225 |
+
torch.save(cubvh, cubvh_path)
|
| 226 |
+
geo["cubvh"] = cubvh
|
| 227 |
+
return geo
|
| 228 |
+
|
| 229 |
+
def get_instance(self, root, instance):
|
| 230 |
+
data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
|
| 231 |
+
coords = torch.tensor(data['coords']).int()
|
| 232 |
+
feats = torch.tensor(data['feats']).float()
|
| 233 |
+
coords_skl = torch.tensor(data['coords_skl']).int()
|
| 234 |
+
feats_skl = torch.tensor(data['feats_skl']).float()
|
| 235 |
+
if self.normalization is not None:
|
| 236 |
+
feats = (feats - self.mean) / self.std
|
| 237 |
+
feats_skl = (feats_skl - self.mean_skl) / self.std_skl
|
| 238 |
+
return {
|
| 239 |
+
'coords': coords,
|
| 240 |
+
'feats': feats,
|
| 241 |
+
'coords_skl': coords_skl,
|
| 242 |
+
'feats_skl': feats_skl,
|
| 243 |
+
'instance': instance,
|
| 244 |
+
**self._get_skeleton(root, instance),
|
| 245 |
+
**self._get_geo(root, instance),
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
@staticmethod
|
| 249 |
+
def collate_fn(batch, split_size=None):
|
| 250 |
+
if split_size is None:
|
| 251 |
+
group_idx = [list(range(len(batch)))]
|
| 252 |
+
else:
|
| 253 |
+
group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
|
| 254 |
+
packs = []
|
| 255 |
+
for group in group_idx:
|
| 256 |
+
sub_batch = [batch[i] for i in group]
|
| 257 |
+
pack = {}
|
| 258 |
+
coords = []
|
| 259 |
+
feats = []
|
| 260 |
+
coords_skl = []
|
| 261 |
+
feats_skl = []
|
| 262 |
+
layout = []
|
| 263 |
+
layout_skl = []
|
| 264 |
+
start = 0
|
| 265 |
+
start_skl = 0
|
| 266 |
+
for i, b in enumerate(sub_batch):
|
| 267 |
+
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
|
| 268 |
+
feats.append(b['feats'])
|
| 269 |
+
coords_skl.append(torch.cat([torch.full((b['coords_skl'].shape[0], 1), i, dtype=torch.int32), b['coords_skl']], dim=-1))
|
| 270 |
+
feats_skl.append(b['feats_skl'])
|
| 271 |
+
layout.append(slice(start, start + b['coords'].shape[0]))
|
| 272 |
+
layout_skl.append(slice(start_skl, start_skl + b['coords_skl'].shape[0]))
|
| 273 |
+
start += b['coords'].shape[0]
|
| 274 |
+
start_skl += b['coords_skl'].shape[0]
|
| 275 |
+
coords = torch.cat(coords)
|
| 276 |
+
feats = torch.cat(feats)
|
| 277 |
+
pack['x_0'] = SparseTensor(
|
| 278 |
+
coords=coords,
|
| 279 |
+
feats=feats,
|
| 280 |
+
)
|
| 281 |
+
pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
|
| 282 |
+
pack['x_0'].register_spatial_cache('layout', layout)
|
| 283 |
+
|
| 284 |
+
coords_skl = torch.cat(coords_skl)
|
| 285 |
+
feats_skl = torch.cat(feats_skl)
|
| 286 |
+
pack['x_0_skl'] = SparseTensor(
|
| 287 |
+
coords=coords_skl,
|
| 288 |
+
feats=feats_skl,
|
| 289 |
+
)
|
| 290 |
+
pack['x_0_skl']._shape = torch.Size([len(group), *sub_batch[0]['feats_skl'].shape[1:]])
|
| 291 |
+
pack['x_0_skl'].register_spatial_cache('layout', layout_skl)
|
| 292 |
+
|
| 293 |
+
pack['joints'] = [b['joints'] for b in sub_batch]
|
| 294 |
+
pack['parents'] = [b['parents'] for b in sub_batch]
|
| 295 |
+
pack['skin'] = [b['skin'] for b in sub_batch]
|
| 296 |
+
if 'joints_num' in sub_batch[0]:
|
| 297 |
+
pack['joints_num'] = torch.tensor([b['joints_num'] for b in sub_batch], dtype=torch.long)
|
| 298 |
+
|
| 299 |
+
# collate other data
|
| 300 |
+
keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats', 'coords_skl', 'feats_skl', 'joints', 'parents', 'skin', 'joints_num']]
|
| 301 |
+
for k in keys:
|
| 302 |
+
if isinstance(sub_batch[0][k], torch.Tensor):
|
| 303 |
+
pack[k] = torch.stack([b[k] for b in sub_batch])
|
| 304 |
+
elif isinstance(sub_batch[0][k], list):
|
| 305 |
+
pack[k] = sum([b[k] for b in sub_batch], [])
|
| 306 |
+
else:
|
| 307 |
+
pack[k] = [b[k] for b in sub_batch]
|
| 308 |
+
|
| 309 |
+
packs.append(pack)
|
| 310 |
+
|
| 311 |
+
if split_size is None:
|
| 312 |
+
return packs[0]
|
| 313 |
+
return packs
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class TextConditionedSLat(TextConditionedMixin, AniGenSLat):
|
| 317 |
+
"""
|
| 318 |
+
Text conditioned structured latent dataset
|
| 319 |
+
"""
|
| 320 |
+
pass
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class AniGenImageConditionedSLat(ImageConditionedMixin, AniGenSLat):
|
| 324 |
+
"""
|
| 325 |
+
Image conditioned structured latent dataset
|
| 326 |
+
"""
|
| 327 |
+
pass
|
anigen/datasets/components.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class StandardDatasetBase(Dataset):
|
| 13 |
+
"""
|
| 14 |
+
Base class for standard datasets.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
roots (str): paths to the dataset
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self,
|
| 21 |
+
roots: str,
|
| 22 |
+
instances: List[str] = None,
|
| 23 |
+
**kwargs,
|
| 24 |
+
):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.roots = roots.split(',')
|
| 27 |
+
self.instances = []
|
| 28 |
+
self.metadata = pd.DataFrame()
|
| 29 |
+
|
| 30 |
+
self._stats = {}
|
| 31 |
+
for root in self.roots:
|
| 32 |
+
key = os.path.basename(root)
|
| 33 |
+
self._stats[key] = {}
|
| 34 |
+
metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
|
| 35 |
+
self._stats[key]['Total'] = len(metadata)
|
| 36 |
+
metadata, stats = self.filter_metadata(metadata)
|
| 37 |
+
self._stats[key].update(stats)
|
| 38 |
+
self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
|
| 39 |
+
metadata.set_index('sha256', inplace=True)
|
| 40 |
+
self.metadata = pd.concat([self.metadata, metadata])
|
| 41 |
+
|
| 42 |
+
if instances is not None:
|
| 43 |
+
self.test_mode = False
|
| 44 |
+
self.instances = [inst for inst in self.instances if inst[1] in instances]
|
| 45 |
+
|
| 46 |
+
@abstractmethod
|
| 47 |
+
def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
|
| 48 |
+
pass
|
| 49 |
+
|
| 50 |
+
@abstractmethod
|
| 51 |
+
def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
def __len__(self):
|
| 55 |
+
return len(self.instances)
|
| 56 |
+
|
| 57 |
+
def __getitem__(self, index) -> Dict[str, Any]:
|
| 58 |
+
try:
|
| 59 |
+
root, instance = self.instances[index]
|
| 60 |
+
return self.get_instance(root, instance)
|
| 61 |
+
except Exception as e:
|
| 62 |
+
print(e)
|
| 63 |
+
return self.__getitem__(np.random.randint(0, len(self)))
|
| 64 |
+
|
| 65 |
+
def __str__(self):
|
| 66 |
+
lines = []
|
| 67 |
+
lines.append(self.__class__.__name__)
|
| 68 |
+
lines.append(f' - Total instances: {len(self)}')
|
| 69 |
+
lines.append(f' - Sources:')
|
| 70 |
+
for key, stats in self._stats.items():
|
| 71 |
+
lines.append(f' - {key}:')
|
| 72 |
+
for k, v in stats.items():
|
| 73 |
+
lines.append(f' - {k}: {v}')
|
| 74 |
+
return '\n'.join(lines)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class TextConditionedMixin:
|
| 78 |
+
def __init__(self, roots, **kwargs):
|
| 79 |
+
super().__init__(roots, **kwargs)
|
| 80 |
+
self.captions = {}
|
| 81 |
+
for instance in self.instances:
|
| 82 |
+
sha256 = instance[1]
|
| 83 |
+
self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])
|
| 84 |
+
|
| 85 |
+
def filter_metadata(self, metadata):
|
| 86 |
+
metadata, stats = super().filter_metadata(metadata)
|
| 87 |
+
metadata = metadata[metadata['captions'].notna()]
|
| 88 |
+
stats['With captions'] = len(metadata)
|
| 89 |
+
return metadata, stats
|
| 90 |
+
|
| 91 |
+
def get_instance(self, root, instance):
|
| 92 |
+
pack = super().get_instance(root, instance)
|
| 93 |
+
text = np.random.choice(self.captions[instance])
|
| 94 |
+
pack['cond'] = text
|
| 95 |
+
return pack
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class ImageConditionedMixin:
|
| 99 |
+
def __init__(self, roots, *, image_size=518, **kwargs):
|
| 100 |
+
self.image_size = image_size
|
| 101 |
+
super().__init__(roots, **kwargs)
|
| 102 |
+
|
| 103 |
+
def filter_metadata(self, metadata):
|
| 104 |
+
metadata, stats = super().filter_metadata(metadata)
|
| 105 |
+
metadata = metadata[metadata[f'cond_rendered']]
|
| 106 |
+
stats['Cond rendered'] = len(metadata)
|
| 107 |
+
return metadata, stats
|
| 108 |
+
|
| 109 |
+
def get_instance(self, root, instance):
|
| 110 |
+
pack = super().get_instance(root, instance)
|
| 111 |
+
|
| 112 |
+
image_root = os.path.join(root, 'renders_cond', instance)
|
| 113 |
+
with open(os.path.join(image_root, 'transforms.json')) as f:
|
| 114 |
+
metadata = json.load(f)
|
| 115 |
+
n_views = len(metadata['frames'])
|
| 116 |
+
view = np.random.randint(n_views)
|
| 117 |
+
metadata = metadata['frames'][view]
|
| 118 |
+
|
| 119 |
+
image_path = os.path.join(image_root, metadata['file_path'])
|
| 120 |
+
image = Image.open(image_path)
|
| 121 |
+
|
| 122 |
+
alpha = np.array(image.getchannel(3))
|
| 123 |
+
bbox = np.array(alpha).nonzero()
|
| 124 |
+
bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
|
| 125 |
+
center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
|
| 126 |
+
hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
|
| 127 |
+
aug_size_ratio = 1.2
|
| 128 |
+
aug_hsize = hsize * aug_size_ratio
|
| 129 |
+
aug_center_offset = [0, 0]
|
| 130 |
+
aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
|
| 131 |
+
aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
|
| 132 |
+
image = image.crop(aug_bbox)
|
| 133 |
+
|
| 134 |
+
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
| 135 |
+
alpha = image.getchannel(3)
|
| 136 |
+
image = image.convert('RGB')
|
| 137 |
+
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
|
| 138 |
+
alpha = torch.tensor(np.array(alpha)).float() / 255.0
|
| 139 |
+
image = image * alpha.unsqueeze(0)
|
| 140 |
+
pack['cond'] = image
|
| 141 |
+
|
| 142 |
+
return pack
|
| 143 |
+
|
anigen/models/__init__.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
|
| 3 |
+
__attributes = {
|
| 4 |
+
'AniGenSparseStructureEncoder': 'anigen_sparse_structure_vae',
|
| 5 |
+
'AniGenSparseStructureDecoder': 'anigen_sparse_structure_vae',
|
| 6 |
+
'AniGenSparseStructureFlowModel': 'anigen_sparse_structure_flow',
|
| 7 |
+
'AniGenSparseStructureFlowModelInpaint': 'anigen_sparse_structure_flow_inpaint',
|
| 8 |
+
'AniGenElasticSLatEncoder': 'structured_latent_vae',
|
| 9 |
+
'AniGenElasticSLatMeshDecoder': 'structured_latent_vae',
|
| 10 |
+
'AniGenElasticSLatGaussianDecoder': 'structured_latent_vae',
|
| 11 |
+
'AniGenSLatFlowModel': 'anigen_structured_latent_flow',
|
| 12 |
+
'AniGenElasticSLatFlowModel': 'anigen_structured_latent_flow',
|
| 13 |
+
'AniGenElasticSLatFlowModelOld': 'anigen_structured_latent_flow_old',
|
| 14 |
+
'SkinAutoEncoder': 'structured_latent_vae',
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
__submodules = []
|
| 18 |
+
|
| 19 |
+
__all__ = list(__attributes.keys()) + __submodules
|
| 20 |
+
|
| 21 |
+
def __getattr__(name):
|
| 22 |
+
if name not in globals():
|
| 23 |
+
if name in __attributes:
|
| 24 |
+
module_name = __attributes[name]
|
| 25 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
| 26 |
+
globals()[name] = getattr(module, name)
|
| 27 |
+
elif name in __submodules:
|
| 28 |
+
module = importlib.import_module(f".{name}", __name__)
|
| 29 |
+
globals()[name] = module
|
| 30 |
+
else:
|
| 31 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
| 32 |
+
return globals()[name]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def from_pretrained(path: str, **kwargs):
|
| 36 |
+
"""
|
| 37 |
+
Load a model from a pretrained checkpoint.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
|
| 41 |
+
NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
|
| 42 |
+
**kwargs: Additional arguments for the model constructor.
|
| 43 |
+
"""
|
| 44 |
+
import os
|
| 45 |
+
import json
|
| 46 |
+
from safetensors.torch import load_file
|
| 47 |
+
is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
|
| 48 |
+
|
| 49 |
+
if is_local:
|
| 50 |
+
config_file = f"{path}.json"
|
| 51 |
+
model_file = f"{path}.safetensors"
|
| 52 |
+
else:
|
| 53 |
+
print(f"{path}.json and {path}.safetensors not found, trying to download from Hugging Face Hub.")
|
| 54 |
+
from huggingface_hub import hf_hub_download
|
| 55 |
+
path_parts = path.split('/')
|
| 56 |
+
repo_id = f'{path_parts[0]}/{path_parts[1]}'
|
| 57 |
+
model_name = '/'.join(path_parts[2:])
|
| 58 |
+
config_file = hf_hub_download(repo_id, f"{model_name}.json")
|
| 59 |
+
model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
|
| 60 |
+
|
| 61 |
+
with open(config_file, 'r') as f:
|
| 62 |
+
config = json.load(f)
|
| 63 |
+
model = __getattr__(config['name'])(**config['args'], **kwargs)
|
| 64 |
+
model.load_state_dict(load_file(model_file))
|
| 65 |
+
|
| 66 |
+
return model
|
| 67 |
+
|
anigen/models/anigen_sparse_structure_flow.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
from ..modules.utils import convert_module_to_f16, convert_module_to_f32
|
| 7 |
+
from ..modules.transformer import AbsolutePositionEmbedder, ModulatedTransformerCrossBlock
|
| 8 |
+
from ..modules.spatial import patchify, unpatchify
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TimestepEmbedder(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Embeds scalar timesteps into vector representations.
|
| 14 |
+
"""
|
| 15 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.mlp = nn.Sequential(
|
| 18 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 19 |
+
nn.SiLU(),
|
| 20 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 21 |
+
)
|
| 22 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 23 |
+
|
| 24 |
+
@staticmethod
|
| 25 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 26 |
+
"""
|
| 27 |
+
Create sinusoidal timestep embeddings.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
t: a 1-D Tensor of N indices, one per batch element.
|
| 31 |
+
These may be fractional.
|
| 32 |
+
dim: the dimension of the output.
|
| 33 |
+
max_period: controls the minimum frequency of the embeddings.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
an (N, D) Tensor of positional embeddings.
|
| 37 |
+
"""
|
| 38 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 39 |
+
half = dim // 2
|
| 40 |
+
freqs = torch.exp(
|
| 41 |
+
-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 42 |
+
).to(device=t.device)
|
| 43 |
+
args = t[:, None].float() * freqs[None]
|
| 44 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 45 |
+
if dim % 2:
|
| 46 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 47 |
+
return embedding
|
| 48 |
+
|
| 49 |
+
def forward(self, t):
|
| 50 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 51 |
+
t_emb = self.mlp(t_freq)
|
| 52 |
+
return t_emb
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class AniGenSparseStructureFlowModel(nn.Module):
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
resolution: int,
|
| 59 |
+
in_channels: int,
|
| 60 |
+
in_channels_skl: int,
|
| 61 |
+
model_channels: int,
|
| 62 |
+
model_channels_skl: int,
|
| 63 |
+
cond_channels: int,
|
| 64 |
+
out_channels: int,
|
| 65 |
+
out_channels_skl: int,
|
| 66 |
+
num_blocks: int,
|
| 67 |
+
num_heads: Optional[int] = None,
|
| 68 |
+
num_head_channels: Optional[int] = 64,
|
| 69 |
+
mlp_ratio: float = 4,
|
| 70 |
+
patch_size: int = 2,
|
| 71 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 72 |
+
use_fp16: bool = False,
|
| 73 |
+
use_checkpoint: bool = False,
|
| 74 |
+
share_mod: bool = False,
|
| 75 |
+
qk_rms_norm: bool = False,
|
| 76 |
+
qk_rms_norm_cross: bool = False,
|
| 77 |
+
use_pretrain_branch: bool = True,
|
| 78 |
+
freeze_pretrain_branch: bool = True,
|
| 79 |
+
use_lora_ss: bool = False,
|
| 80 |
+
lora_lr_rate_ss: float = 0.1,
|
| 81 |
+
modules_to_freeze: Optional[List[str]] = ["blocks", "input_layer", "out_layer", "pos_emb", "t_embedder"],
|
| 82 |
+
adapter_ss_to_skl: bool = True,
|
| 83 |
+
adapter_skl_to_ss: bool = True,
|
| 84 |
+
predict_x0: bool = False,
|
| 85 |
+
predict_x0_skl: bool = False,
|
| 86 |
+
t_eps: float = 5e-2,
|
| 87 |
+
t_scale: float = 1e3,
|
| 88 |
+
z_is_global: bool = False,
|
| 89 |
+
z_skl_is_global: bool = False,
|
| 90 |
+
global_token_num: int = 1024,
|
| 91 |
+
global_token_num_skl: int = 1024,
|
| 92 |
+
cross_adapter_every: int = 4,
|
| 93 |
+
skl_cross_from_ss: bool = False,
|
| 94 |
+
):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.resolution = resolution
|
| 97 |
+
self.in_channels = in_channels
|
| 98 |
+
self.in_channels_skl = in_channels_skl
|
| 99 |
+
self.model_channels = model_channels
|
| 100 |
+
self.model_channels_skl = model_channels_skl
|
| 101 |
+
self.cond_channels = cond_channels
|
| 102 |
+
self.out_channels = out_channels
|
| 103 |
+
self.out_channels_skl = out_channels_skl
|
| 104 |
+
self.num_blocks = num_blocks
|
| 105 |
+
self.num_heads = num_heads or model_channels // num_head_channels
|
| 106 |
+
self.mlp_ratio = mlp_ratio
|
| 107 |
+
self.patch_size = patch_size
|
| 108 |
+
self.pe_mode = pe_mode
|
| 109 |
+
self.use_fp16 = use_fp16
|
| 110 |
+
self.use_checkpoint = use_checkpoint
|
| 111 |
+
self.share_mod = share_mod
|
| 112 |
+
self.qk_rms_norm = qk_rms_norm
|
| 113 |
+
self.qk_rms_norm_cross = qk_rms_norm_cross
|
| 114 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 115 |
+
self.use_pretrain_branch = use_pretrain_branch
|
| 116 |
+
self.freeze_pretrain_branch = freeze_pretrain_branch or use_lora_ss
|
| 117 |
+
self.use_lora_ss = use_lora_ss
|
| 118 |
+
self.modules_to_freeze = modules_to_freeze
|
| 119 |
+
self.adapter_ss_to_skl = adapter_ss_to_skl
|
| 120 |
+
self.adapter_skl_to_ss = adapter_skl_to_ss
|
| 121 |
+
self.predict_x0 = predict_x0
|
| 122 |
+
self.predict_x0_skl = predict_x0_skl
|
| 123 |
+
self.t_eps = t_eps
|
| 124 |
+
self.t_scale = t_scale
|
| 125 |
+
self.z_is_global = z_is_global
|
| 126 |
+
self.z_skl_is_global = z_skl_is_global
|
| 127 |
+
self.global_token_num = global_token_num
|
| 128 |
+
self.global_token_num_skl = global_token_num_skl
|
| 129 |
+
self.cross_adapter_every = int(cross_adapter_every)
|
| 130 |
+
self.skl_cross_from_ss = skl_cross_from_ss
|
| 131 |
+
|
| 132 |
+
self.t_embedder = TimestepEmbedder(model_channels)
|
| 133 |
+
self.t_embedder_skl = TimestepEmbedder(model_channels_skl)
|
| 134 |
+
if share_mod:
|
| 135 |
+
self.adaLN_modulation = nn.Sequential(
|
| 136 |
+
nn.SiLU(),
|
| 137 |
+
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
| 138 |
+
)
|
| 139 |
+
self.adaLN_modulation_skl = nn.Sequential(
|
| 140 |
+
nn.SiLU(),
|
| 141 |
+
nn.Linear(model_channels_skl, 6 * model_channels_skl, bias=True)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if pe_mode == "ape":
|
| 145 |
+
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution // patch_size] * 3], indexing='ij')
|
| 146 |
+
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
| 147 |
+
if self.z_is_global:
|
| 148 |
+
pos_embedder = AbsolutePositionEmbedder(model_channels, 1)
|
| 149 |
+
pos_emb = pos_embedder(torch.arange(self.global_token_num, device=self.device)[:, None])
|
| 150 |
+
else:
|
| 151 |
+
pos_embedder = AbsolutePositionEmbedder(model_channels, 3)
|
| 152 |
+
pos_emb = pos_embedder(coords)
|
| 153 |
+
self.register_buffer("pos_emb", pos_emb)
|
| 154 |
+
if self.z_skl_is_global:
|
| 155 |
+
pos_embedder_skl = AbsolutePositionEmbedder(model_channels_skl, 1)
|
| 156 |
+
pos_emb_skl = pos_embedder_skl(torch.arange(self.global_token_num_skl, device=self.device)[:, None])
|
| 157 |
+
else:
|
| 158 |
+
pos_embedder_skl = AbsolutePositionEmbedder(model_channels_skl, 3)
|
| 159 |
+
pos_emb_skl = pos_embedder_skl(coords)
|
| 160 |
+
self.register_buffer("pos_emb_skl", pos_emb_skl)
|
| 161 |
+
|
| 162 |
+
self.input_layer = nn.Linear(in_channels * patch_size**3, model_channels)
|
| 163 |
+
self.input_layer_skl = nn.Linear(in_channels_skl * patch_size**3, model_channels_skl)
|
| 164 |
+
|
| 165 |
+
shallow = max(1, num_blocks // 3)
|
| 166 |
+
middle = max(1, num_blocks // 3 * 2)
|
| 167 |
+
self.blocks = nn.ModuleList([
|
| 168 |
+
ModulatedTransformerCrossBlock(
|
| 169 |
+
model_channels,
|
| 170 |
+
cond_channels,
|
| 171 |
+
num_heads=self.num_heads,
|
| 172 |
+
mlp_ratio=self.mlp_ratio,
|
| 173 |
+
attn_mode='full',
|
| 174 |
+
use_checkpoint=self.use_checkpoint,
|
| 175 |
+
use_rope=(pe_mode == "rope"),
|
| 176 |
+
share_mod=share_mod,
|
| 177 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 178 |
+
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
| 179 |
+
use_lora_self=self.use_lora_ss and idx >= middle,
|
| 180 |
+
lora_rank_self=8,
|
| 181 |
+
use_lora_cross=self.use_lora_ss,
|
| 182 |
+
lora_rank_cross=8+(idx // shallow)*8,
|
| 183 |
+
lora_lr_rate=lora_lr_rate_ss,
|
| 184 |
+
)
|
| 185 |
+
for idx in range(num_blocks)
|
| 186 |
+
])
|
| 187 |
+
self.blocks_skl = nn.ModuleList([
|
| 188 |
+
ModulatedTransformerCrossBlock(
|
| 189 |
+
model_channels_skl,
|
| 190 |
+
cond_channels if not self.skl_cross_from_ss else model_channels,
|
| 191 |
+
num_heads=self.num_heads,
|
| 192 |
+
mlp_ratio=self.mlp_ratio,
|
| 193 |
+
attn_mode='full',
|
| 194 |
+
use_checkpoint=self.use_checkpoint,
|
| 195 |
+
use_rope=(pe_mode == "rope"),
|
| 196 |
+
share_mod=share_mod,
|
| 197 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 198 |
+
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
| 199 |
+
use_context_norm=self.skl_cross_from_ss,
|
| 200 |
+
)
|
| 201 |
+
for _ in range(num_blocks)
|
| 202 |
+
])
|
| 203 |
+
|
| 204 |
+
# When using global tokens, ss and skl token counts may differ, so we use cross-attention
|
| 205 |
+
# for information exchange at a configurable frequency.
|
| 206 |
+
self.use_cross_adapter = (self.z_is_global or self.z_skl_is_global) and (
|
| 207 |
+
self.adapter_ss_to_skl or self.adapter_skl_to_ss
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
if self.adapter_ss_to_skl and not self.use_cross_adapter:
|
| 211 |
+
self.adapter_ss_to_skl_layers = nn.ModuleList([
|
| 212 |
+
nn.Linear(model_channels, model_channels_skl) for _ in range(num_blocks)
|
| 213 |
+
])
|
| 214 |
+
if self.adapter_skl_to_ss and not self.use_cross_adapter:
|
| 215 |
+
self.adapter_skl_to_ss_layers = nn.ModuleList([
|
| 216 |
+
nn.Linear(model_channels_skl, model_channels) for _ in range(num_blocks)
|
| 217 |
+
])
|
| 218 |
+
|
| 219 |
+
self.cross_adapter_every = max(1, self.cross_adapter_every)
|
| 220 |
+
self.cross_block_indices: List[int] = [
|
| 221 |
+
idx for idx in range(num_blocks) if (idx + 1) % self.cross_adapter_every == 0
|
| 222 |
+
]
|
| 223 |
+
if self.use_cross_adapter and len(self.cross_block_indices) == 0 and num_blocks > 0:
|
| 224 |
+
self.cross_block_indices = [num_blocks - 1]
|
| 225 |
+
if self.use_cross_adapter and len(self.cross_block_indices) > 0:
|
| 226 |
+
if self.adapter_ss_to_skl:
|
| 227 |
+
self.cross_blocks_ss_to_skl = nn.ModuleList([
|
| 228 |
+
ModulatedTransformerCrossBlock(
|
| 229 |
+
model_channels_skl,
|
| 230 |
+
model_channels,
|
| 231 |
+
num_heads=self.num_heads,
|
| 232 |
+
mlp_ratio=self.mlp_ratio,
|
| 233 |
+
attn_mode='full',
|
| 234 |
+
use_checkpoint=self.use_checkpoint,
|
| 235 |
+
use_rope=(pe_mode == "rope"),
|
| 236 |
+
share_mod=share_mod,
|
| 237 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 238 |
+
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
| 239 |
+
)
|
| 240 |
+
for _ in self.cross_block_indices
|
| 241 |
+
])
|
| 242 |
+
self.cross_blocks_ss_to_skl_out = nn.ModuleList([
|
| 243 |
+
nn.Linear(model_channels_skl, model_channels_skl, bias=True)
|
| 244 |
+
for _ in self.cross_block_indices
|
| 245 |
+
])
|
| 246 |
+
if self.adapter_skl_to_ss:
|
| 247 |
+
self.cross_blocks_skl_to_ss = nn.ModuleList([
|
| 248 |
+
ModulatedTransformerCrossBlock(
|
| 249 |
+
model_channels,
|
| 250 |
+
model_channels_skl,
|
| 251 |
+
num_heads=self.num_heads,
|
| 252 |
+
mlp_ratio=self.mlp_ratio,
|
| 253 |
+
attn_mode='full',
|
| 254 |
+
use_checkpoint=self.use_checkpoint,
|
| 255 |
+
use_rope=(pe_mode == "rope"),
|
| 256 |
+
share_mod=share_mod,
|
| 257 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 258 |
+
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
| 259 |
+
)
|
| 260 |
+
for _ in self.cross_block_indices
|
| 261 |
+
])
|
| 262 |
+
self.cross_blocks_skl_to_ss_out = nn.ModuleList([
|
| 263 |
+
nn.Linear(model_channels, model_channels, bias=True)
|
| 264 |
+
for _ in self.cross_block_indices
|
| 265 |
+
])
|
| 266 |
+
|
| 267 |
+
self.out_layer = nn.Linear(model_channels, out_channels * patch_size**3)
|
| 268 |
+
self.out_layer_skl = nn.Linear(model_channels_skl, out_channels_skl * patch_size**3)
|
| 269 |
+
|
| 270 |
+
self.initialize_weights()
|
| 271 |
+
if use_fp16:
|
| 272 |
+
self.convert_to_fp16()
|
| 273 |
+
|
| 274 |
+
if self.use_pretrain_branch and self.freeze_pretrain_branch:
|
| 275 |
+
for module in modules_to_freeze:
|
| 276 |
+
if hasattr(self, module):
|
| 277 |
+
mod = getattr(self, module)
|
| 278 |
+
if isinstance(mod, nn.ModuleList):
|
| 279 |
+
for m in mod:
|
| 280 |
+
for name, param in m.named_parameters():
|
| 281 |
+
if 'lora' not in name:
|
| 282 |
+
param.requires_grad = False
|
| 283 |
+
elif isinstance(mod, nn.Module):
|
| 284 |
+
for name, param in mod.named_parameters():
|
| 285 |
+
if 'lora' not in name:
|
| 286 |
+
param.requires_grad = False
|
| 287 |
+
elif isinstance(mod, torch.Tensor):
|
| 288 |
+
if mod.requires_grad:
|
| 289 |
+
mod.requires_grad = False
|
| 290 |
+
|
| 291 |
+
@property
|
| 292 |
+
def device(self) -> torch.device:
|
| 293 |
+
"""
|
| 294 |
+
Return the device of the model.
|
| 295 |
+
"""
|
| 296 |
+
return next(self.parameters()).device
|
| 297 |
+
|
| 298 |
+
def convert_to_fp16(self) -> None:
|
| 299 |
+
"""
|
| 300 |
+
Convert the torso of the model to float16.
|
| 301 |
+
"""
|
| 302 |
+
self.blocks.apply(convert_module_to_f16)
|
| 303 |
+
self.blocks_skl.apply(convert_module_to_f16)
|
| 304 |
+
if hasattr(self, "adapter_ss_to_skl_layers"):
|
| 305 |
+
self.adapter_ss_to_skl_layers.apply(convert_module_to_f16)
|
| 306 |
+
if hasattr(self, "adapter_skl_to_ss_layers"):
|
| 307 |
+
self.adapter_skl_to_ss_layers.apply(convert_module_to_f16)
|
| 308 |
+
if getattr(self, "use_cross_adapter", False):
|
| 309 |
+
if hasattr(self, "cross_blocks_ss_to_skl"):
|
| 310 |
+
self.cross_blocks_ss_to_skl.apply(convert_module_to_f16)
|
| 311 |
+
self.cross_blocks_ss_to_skl_out.apply(convert_module_to_f16)
|
| 312 |
+
if hasattr(self, "cross_blocks_skl_to_ss"):
|
| 313 |
+
self.cross_blocks_skl_to_ss.apply(convert_module_to_f16)
|
| 314 |
+
self.cross_blocks_skl_to_ss_out.apply(convert_module_to_f16)
|
| 315 |
+
|
| 316 |
+
def convert_to_fp32(self) -> None:
|
| 317 |
+
"""
|
| 318 |
+
Convert the torso of the model to float32.
|
| 319 |
+
"""
|
| 320 |
+
self.blocks.apply(convert_module_to_f32)
|
| 321 |
+
self.blocks_skl.apply(convert_module_to_f32)
|
| 322 |
+
if hasattr(self, "adapter_ss_to_skl_layers"):
|
| 323 |
+
self.adapter_ss_to_skl_layers.apply(convert_module_to_f32)
|
| 324 |
+
if hasattr(self, "adapter_skl_to_ss_layers"):
|
| 325 |
+
self.adapter_skl_to_ss_layers.apply(convert_module_to_f32)
|
| 326 |
+
if getattr(self, "use_cross_adapter", False):
|
| 327 |
+
if hasattr(self, "cross_blocks_ss_to_skl"):
|
| 328 |
+
self.cross_blocks_ss_to_skl.apply(convert_module_to_f32)
|
| 329 |
+
self.cross_blocks_ss_to_skl_out.apply(convert_module_to_f32)
|
| 330 |
+
if hasattr(self, "cross_blocks_skl_to_ss"):
|
| 331 |
+
self.cross_blocks_skl_to_ss.apply(convert_module_to_f32)
|
| 332 |
+
self.cross_blocks_skl_to_ss_out.apply(convert_module_to_f32)
|
| 333 |
+
|
| 334 |
+
def initialize_weights(self) -> None:
|
| 335 |
+
# Initialize transformer layers:
|
| 336 |
+
def _basic_init(module):
|
| 337 |
+
if isinstance(module, nn.Linear):
|
| 338 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 339 |
+
if module.bias is not None:
|
| 340 |
+
nn.init.constant_(module.bias, 0)
|
| 341 |
+
self.apply(_basic_init)
|
| 342 |
+
|
| 343 |
+
# Initialize timestep embedding MLP:
|
| 344 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 345 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 346 |
+
nn.init.normal_(self.t_embedder_skl.mlp[0].weight, std=0.02)
|
| 347 |
+
nn.init.normal_(self.t_embedder_skl.mlp[2].weight, std=0.02)
|
| 348 |
+
|
| 349 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
| 350 |
+
if self.share_mod:
|
| 351 |
+
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
| 352 |
+
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
| 353 |
+
nn.init.constant_(self.adaLN_modulation_skl[-1].weight, 0)
|
| 354 |
+
nn.init.constant_(self.adaLN_modulation_skl[-1].bias, 0)
|
| 355 |
+
else:
|
| 356 |
+
for block in self.blocks:
|
| 357 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 358 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 359 |
+
for block in self.blocks_skl:
|
| 360 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 361 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 362 |
+
|
| 363 |
+
# Zero-out output layers:
|
| 364 |
+
nn.init.constant_(self.out_layer.weight, 0)
|
| 365 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
| 366 |
+
nn.init.constant_(self.out_layer_skl.weight, 0)
|
| 367 |
+
nn.init.constant_(self.out_layer_skl.bias, 0)
|
| 368 |
+
|
| 369 |
+
# Zero-out adapter layers if exist
|
| 370 |
+
if hasattr(self, "adapter_ss_to_skl_layers"):
|
| 371 |
+
for layer in self.adapter_ss_to_skl_layers:
|
| 372 |
+
nn.init.constant_(layer.weight, 0)
|
| 373 |
+
nn.init.constant_(layer.bias, 0)
|
| 374 |
+
if hasattr(self, "adapter_skl_to_ss_layers"):
|
| 375 |
+
for layer in self.adapter_skl_to_ss_layers:
|
| 376 |
+
nn.init.constant_(layer.weight, 0)
|
| 377 |
+
nn.init.constant_(layer.bias, 0)
|
| 378 |
+
|
| 379 |
+
# Zero-out cross adapter output projections (so we can safely finetune from pretrained ckpt)
|
| 380 |
+
if getattr(self, "use_cross_adapter", False):
|
| 381 |
+
if hasattr(self, "cross_blocks_ss_to_skl_out"):
|
| 382 |
+
for layer in self.cross_blocks_ss_to_skl_out:
|
| 383 |
+
nn.init.constant_(layer.weight, 0)
|
| 384 |
+
nn.init.constant_(layer.bias, 0)
|
| 385 |
+
if hasattr(self, "cross_blocks_skl_to_ss_out"):
|
| 386 |
+
for layer in self.cross_blocks_skl_to_ss_out:
|
| 387 |
+
nn.init.constant_(layer.weight, 0)
|
| 388 |
+
nn.init.constant_(layer.bias, 0)
|
| 389 |
+
|
| 390 |
+
def forward(self, x: torch.Tensor, x_skl: torch.Tensor, t: torch.Tensor, cond: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 391 |
+
if not self.z_is_global:
|
| 392 |
+
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
|
| 393 |
+
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
|
| 394 |
+
if not self.z_skl_is_global:
|
| 395 |
+
assert [*x_skl.shape] == [x_skl.shape[0], self.in_channels_skl, *[self.resolution] * 3], \
|
| 396 |
+
f"Input shape mismatch, got {x_skl.shape}, expected {[x_skl.shape[0], self.in_channels_skl, *[self.resolution] * 3]}"
|
| 397 |
+
|
| 398 |
+
if self.predict_x0:
|
| 399 |
+
xt = x.clone()
|
| 400 |
+
if self.predict_x0_skl:
|
| 401 |
+
xt_skl = x_skl.clone()
|
| 402 |
+
|
| 403 |
+
if not self.z_is_global:
|
| 404 |
+
h = patchify(x, self.patch_size)
|
| 405 |
+
h = h.view(*h.shape[:2], -1).permute(0, 2, 1).contiguous()
|
| 406 |
+
else:
|
| 407 |
+
h = x
|
| 408 |
+
if not self.z_skl_is_global:
|
| 409 |
+
h_skl = patchify(x_skl, self.patch_size)
|
| 410 |
+
h_skl = h_skl.view(*h_skl.shape[:2], -1).permute(0, 2, 1).contiguous()
|
| 411 |
+
else:
|
| 412 |
+
h_skl = x_skl
|
| 413 |
+
|
| 414 |
+
h = self.input_layer(h)
|
| 415 |
+
h = h + self.pos_emb[None]
|
| 416 |
+
h_skl = self.input_layer_skl(h_skl)
|
| 417 |
+
h_skl = h_skl + self.pos_emb_skl[None]
|
| 418 |
+
|
| 419 |
+
t_emb = self.t_embedder(t)
|
| 420 |
+
t_emb_skl = self.t_embedder_skl(t)
|
| 421 |
+
if self.share_mod:
|
| 422 |
+
t_emb = self.adaLN_modulation(t_emb)
|
| 423 |
+
t_emb_skl = self.adaLN_modulation_skl(t_emb_skl)
|
| 424 |
+
t_emb = t_emb.type(self.dtype)
|
| 425 |
+
t_emb_skl = t_emb_skl.type(self.dtype)
|
| 426 |
+
|
| 427 |
+
h = h.type(self.dtype)
|
| 428 |
+
h_skl = h_skl.type(self.dtype)
|
| 429 |
+
cond = cond.type(self.dtype)
|
| 430 |
+
|
| 431 |
+
cross_pos_to_idx = None
|
| 432 |
+
if self.use_cross_adapter and len(self.cross_block_indices) > 0:
|
| 433 |
+
cross_pos_to_idx = {bidx: cidx for cidx, bidx in enumerate(self.cross_block_indices)}
|
| 434 |
+
|
| 435 |
+
for idx, block, block_skl in zip(range(len(self.blocks)), self.blocks, self.blocks_skl):
|
| 436 |
+
f = block(h, t_emb, cond)
|
| 437 |
+
f_skl = block_skl(h_skl, t_emb_skl, h if self.skl_cross_from_ss else cond)
|
| 438 |
+
|
| 439 |
+
if self.use_cross_adapter and cross_pos_to_idx is not None and idx in cross_pos_to_idx:
|
| 440 |
+
cidx = cross_pos_to_idx[idx]
|
| 441 |
+
if self.adapter_ss_to_skl:
|
| 442 |
+
out_skl = self.cross_blocks_ss_to_skl[cidx](f_skl, t_emb_skl, f)
|
| 443 |
+
h_skl = f_skl + self.cross_blocks_ss_to_skl_out[cidx](out_skl - f_skl)
|
| 444 |
+
else:
|
| 445 |
+
h_skl = f_skl
|
| 446 |
+
|
| 447 |
+
if self.adapter_skl_to_ss:
|
| 448 |
+
out = self.cross_blocks_skl_to_ss[cidx](f, t_emb, f_skl)
|
| 449 |
+
h = f + self.cross_blocks_skl_to_ss_out[cidx](out - f)
|
| 450 |
+
else:
|
| 451 |
+
h = f
|
| 452 |
+
else:
|
| 453 |
+
# Non-global (or no cross block at this idx): keep previous behavior.
|
| 454 |
+
if self.adapter_ss_to_skl and (not self.use_cross_adapter):
|
| 455 |
+
h_skl = f_skl + self.adapter_ss_to_skl_layers[idx](f)
|
| 456 |
+
else:
|
| 457 |
+
h_skl = f_skl
|
| 458 |
+
|
| 459 |
+
if self.adapter_skl_to_ss and (not self.use_cross_adapter):
|
| 460 |
+
h = f + self.adapter_skl_to_ss_layers[idx](f_skl)
|
| 461 |
+
else:
|
| 462 |
+
h = f
|
| 463 |
+
h = h.type(x.dtype)
|
| 464 |
+
h = F.layer_norm(h, h.shape[-1:])
|
| 465 |
+
h = self.out_layer(h)
|
| 466 |
+
h_skl = h_skl.type(x_skl.dtype)
|
| 467 |
+
h_skl = F.layer_norm(h_skl, h_skl.shape[-1:])
|
| 468 |
+
h_skl = self.out_layer_skl(h_skl)
|
| 469 |
+
|
| 470 |
+
if not self.z_is_global:
|
| 471 |
+
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution // self.patch_size] * 3)
|
| 472 |
+
h = unpatchify(h, self.patch_size).contiguous()
|
| 473 |
+
|
| 474 |
+
if not self.z_skl_is_global:
|
| 475 |
+
h_skl = h_skl.permute(0, 2, 1).view(h_skl.shape[0], h_skl.shape[2], *[self.resolution // self.patch_size] * 3)
|
| 476 |
+
h_skl = unpatchify(h_skl, self.patch_size).contiguous()
|
| 477 |
+
|
| 478 |
+
if self.predict_x0:
|
| 479 |
+
t_normalized = t / self.t_scale
|
| 480 |
+
factor = (1 / t_normalized.clamp_min(self.t_eps)).reshape([t.shape[0], *([1] * (x.dim() - 1))])
|
| 481 |
+
h = (xt - h) * factor
|
| 482 |
+
if self.predict_x0_skl:
|
| 483 |
+
t_normalized = t / self.t_scale
|
| 484 |
+
factor = (1 / t_normalized.clamp_min(self.t_eps)).reshape([t.shape[0], *([1] * (x_skl.dim() - 1))])
|
| 485 |
+
h_skl = (xt_skl - h_skl) * factor
|
| 486 |
+
|
| 487 |
+
return h, h_skl
|
anigen/models/anigen_sparse_structure_vae.py
ADDED
|
@@ -0,0 +1,729 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from ..modules.norm import GroupNorm32, ChannelLayerNorm32
|
| 6 |
+
from ..modules.spatial import pixel_shuffle_3d
|
| 7 |
+
from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
|
| 8 |
+
from ..modules.transformer import FeedForwardNet, TransformerBlock, TransformerCrossBlock, AbsolutePositionEmbedder
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
|
| 12 |
+
"""
|
| 13 |
+
Return a normalization layer.
|
| 14 |
+
"""
|
| 15 |
+
if norm_type == "group":
|
| 16 |
+
return GroupNorm32(32, *args, **kwargs)
|
| 17 |
+
elif norm_type == "layer":
|
| 18 |
+
return ChannelLayerNorm32(*args, **kwargs)
|
| 19 |
+
else:
|
| 20 |
+
raise ValueError(f"Invalid norm type {norm_type}")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ResBlock3d(nn.Module):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
channels: int,
|
| 27 |
+
out_channels: Optional[int] = None,
|
| 28 |
+
norm_type: Literal["group", "layer"] = "layer",
|
| 29 |
+
):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.channels = channels
|
| 32 |
+
self.out_channels = out_channels or channels
|
| 33 |
+
|
| 34 |
+
self.norm1 = norm_layer(norm_type, channels)
|
| 35 |
+
self.norm2 = norm_layer(norm_type, self.out_channels)
|
| 36 |
+
self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
|
| 37 |
+
self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
|
| 38 |
+
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
|
| 39 |
+
|
| 40 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 41 |
+
h = self.norm1(x)
|
| 42 |
+
h = F.silu(h)
|
| 43 |
+
h = self.conv1(h)
|
| 44 |
+
h = self.norm2(h)
|
| 45 |
+
h = F.silu(h)
|
| 46 |
+
h = self.conv2(h)
|
| 47 |
+
h = h + self.skip_connection(x)
|
| 48 |
+
return h
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DownsampleBlock3d(nn.Module):
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
in_channels: int,
|
| 55 |
+
out_channels: int,
|
| 56 |
+
mode: Literal["conv", "avgpool"] = "conv",
|
| 57 |
+
):
|
| 58 |
+
assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
|
| 59 |
+
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.in_channels = in_channels
|
| 62 |
+
self.out_channels = out_channels
|
| 63 |
+
|
| 64 |
+
if mode == "conv":
|
| 65 |
+
self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
|
| 66 |
+
elif mode == "avgpool":
|
| 67 |
+
assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
|
| 68 |
+
|
| 69 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 70 |
+
if hasattr(self, "conv"):
|
| 71 |
+
return self.conv(x)
|
| 72 |
+
else:
|
| 73 |
+
return F.avg_pool3d(x, 2)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class UpsampleBlock3d(nn.Module):
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
in_channels: int,
|
| 80 |
+
out_channels: int,
|
| 81 |
+
mode: Literal["conv", "nearest"] = "conv",
|
| 82 |
+
):
|
| 83 |
+
assert mode in ["conv", "nearest"], f"Invalid mode {mode}"
|
| 84 |
+
|
| 85 |
+
super().__init__()
|
| 86 |
+
self.in_channels = in_channels
|
| 87 |
+
self.out_channels = out_channels
|
| 88 |
+
|
| 89 |
+
if mode == "conv":
|
| 90 |
+
self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1)
|
| 91 |
+
elif mode == "nearest":
|
| 92 |
+
assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels"
|
| 93 |
+
|
| 94 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 95 |
+
if hasattr(self, "conv"):
|
| 96 |
+
x = self.conv(x)
|
| 97 |
+
return pixel_shuffle_3d(x, 2)
|
| 98 |
+
else:
|
| 99 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class AniGenSparseStructureEncoder(nn.Module):
|
| 103 |
+
"""
|
| 104 |
+
Encoder for Sparse Structure (\mathcal{E}_S in the paper Sec. 3.3).
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
in_channels (int): Channels of the input.
|
| 108 |
+
latent_channels (int): Channels of the latent representation.
|
| 109 |
+
num_res_blocks (int): Number of residual blocks at each resolution.
|
| 110 |
+
channels (List[int]): Channels of the encoder blocks.
|
| 111 |
+
num_res_blocks_middle (int): Number of residual blocks in the middle.
|
| 112 |
+
norm_type (Literal["group", "layer"]): Type of normalization layer.
|
| 113 |
+
use_fp16 (bool): Whether to use FP16.
|
| 114 |
+
"""
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
in_channels: int,
|
| 118 |
+
in_channels_skl: int,
|
| 119 |
+
latent_channels: int,
|
| 120 |
+
latent_channels_skl: int,
|
| 121 |
+
num_res_blocks: int,
|
| 122 |
+
channels: List[int],
|
| 123 |
+
num_res_blocks_middle: int = 2,
|
| 124 |
+
norm_type: Literal["group", "layer"] = "layer",
|
| 125 |
+
use_fp16: bool = False,
|
| 126 |
+
encode_global: bool = False,
|
| 127 |
+
global_token_num: int = 1024,
|
| 128 |
+
encode_global_skl: bool = True,
|
| 129 |
+
global_token_num_skl: int = 1024,
|
| 130 |
+
use_pretrain_branch: bool = True,
|
| 131 |
+
freeze_pretrain_branch: bool = True,
|
| 132 |
+
modules_to_freeze: Optional[List[str]] = ["input_layer", "blocks", "middle_block", "out_layer"],
|
| 133 |
+
latent_denoising: bool = False,
|
| 134 |
+
latent_denoising_skl: bool = True,
|
| 135 |
+
normalize_z: bool = False,
|
| 136 |
+
normalize_z_skl: bool = True,
|
| 137 |
+
normalize_scale: float = 1.0
|
| 138 |
+
):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.in_channels = in_channels
|
| 141 |
+
self.in_channels_skl = in_channels_skl
|
| 142 |
+
self.latent_channels = latent_channels
|
| 143 |
+
self.latent_channels_skl = latent_channels_skl
|
| 144 |
+
self.num_res_blocks = num_res_blocks
|
| 145 |
+
self.channels = channels
|
| 146 |
+
self.num_res_blocks_middle = num_res_blocks_middle
|
| 147 |
+
self.norm_type = norm_type
|
| 148 |
+
self.use_fp16 = use_fp16
|
| 149 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 150 |
+
self.encode_global = encode_global
|
| 151 |
+
self.global_token_num = global_token_num
|
| 152 |
+
self.encode_global_skl = encode_global_skl
|
| 153 |
+
self.global_token_num_skl = global_token_num_skl
|
| 154 |
+
self.use_pretrain_branch = use_pretrain_branch
|
| 155 |
+
self.freeze_pretrain_branch = freeze_pretrain_branch
|
| 156 |
+
self.latent_denoising = latent_denoising
|
| 157 |
+
self.latent_denoising_skl = latent_denoising_skl
|
| 158 |
+
self.normalize_latent = normalize_z and latent_denoising
|
| 159 |
+
self.normalize_latent_skl = normalize_z_skl and latent_denoising_skl
|
| 160 |
+
self.normalize_scale = normalize_scale
|
| 161 |
+
|
| 162 |
+
self.input_layer = nn.Conv3d(self.in_channels, channels[0], 3, padding=1)
|
| 163 |
+
self.input_layer_skl = nn.Conv3d(self.in_channels_skl, channels[0], 3, padding=1)
|
| 164 |
+
|
| 165 |
+
self.blocks = nn.ModuleList([])
|
| 166 |
+
self.blocks_skl = nn.ModuleList([])
|
| 167 |
+
for i, ch in enumerate(channels):
|
| 168 |
+
self.blocks.extend([
|
| 169 |
+
ResBlock3d(ch, ch)
|
| 170 |
+
for _ in range(num_res_blocks)
|
| 171 |
+
])
|
| 172 |
+
self.blocks_skl.extend([
|
| 173 |
+
ResBlock3d(ch, ch)
|
| 174 |
+
for _ in range(num_res_blocks)
|
| 175 |
+
])
|
| 176 |
+
if i < len(channels) - 1:
|
| 177 |
+
self.blocks.append(
|
| 178 |
+
DownsampleBlock3d(ch, channels[i+1])
|
| 179 |
+
)
|
| 180 |
+
self.blocks_skl.append(
|
| 181 |
+
DownsampleBlock3d(ch, channels[i+1])
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
self.middle_block = nn.Sequential(*[
|
| 185 |
+
ResBlock3d(channels[-1], channels[-1])
|
| 186 |
+
for _ in range(num_res_blocks_middle)
|
| 187 |
+
])
|
| 188 |
+
self.middle_block_skl = nn.Sequential(*[
|
| 189 |
+
ResBlock3d(channels[-1] if _ == 0 else channels[-1], channels[-1])
|
| 190 |
+
for _ in range(num_res_blocks_middle)
|
| 191 |
+
])
|
| 192 |
+
|
| 193 |
+
if self.encode_global:
|
| 194 |
+
# Initial Tokens and PE
|
| 195 |
+
self.init_tokens_ss = nn.Parameter(torch.zeros(1, global_token_num, channels[-1]))
|
| 196 |
+
pos_embedder = AbsolutePositionEmbedder(channels[-1], 1)
|
| 197 |
+
coords = torch.arange(global_token_num, device=self.device).reshape(-1, 1)
|
| 198 |
+
tokens_pos_emb = pos_embedder(coords)
|
| 199 |
+
self.register_buffer('tokens_pos_emb_ss', tokens_pos_emb)
|
| 200 |
+
# Grids PE
|
| 201 |
+
upsample_factor = 2 ** (len(channels) - 1)
|
| 202 |
+
self.base_size_ss = 64 // upsample_factor
|
| 203 |
+
pos_embedder = AbsolutePositionEmbedder(channels[-1], 3)
|
| 204 |
+
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size_ss] * 3], indexing='ij')
|
| 205 |
+
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
| 206 |
+
grid_pos_emb = pos_embedder(coords)
|
| 207 |
+
self.register_buffer("grid_pos_emb_ss", grid_pos_emb)
|
| 208 |
+
# Token projection layer
|
| 209 |
+
self.token_proj_ss = nn.Linear(channels[-1]*2, channels[-1])
|
| 210 |
+
|
| 211 |
+
# Out layers
|
| 212 |
+
self.out_layer = nn.ModuleList(
|
| 213 |
+
[TransformerCrossBlock(
|
| 214 |
+
channels=channels[-1],
|
| 215 |
+
ctx_channels=channels[-1]*2,
|
| 216 |
+
out_channels=channels[-1],
|
| 217 |
+
num_heads=16,
|
| 218 |
+
attn_mode="full",
|
| 219 |
+
qkv_bias=False,
|
| 220 |
+
x_is_query=False)] +
|
| 221 |
+
[TransformerBlock(
|
| 222 |
+
channels=channels[-1],
|
| 223 |
+
out_channels=channels[-1],
|
| 224 |
+
num_heads=16,
|
| 225 |
+
attn_mode="full",
|
| 226 |
+
qkv_bias=False,
|
| 227 |
+
) for _ in range(4)] +
|
| 228 |
+
[FeedForwardNet(
|
| 229 |
+
channels=channels[-1],
|
| 230 |
+
out_channels=latent_channels*2 if not self.latent_denoising else latent_channels)]
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
self.out_layer = nn.Sequential(
|
| 234 |
+
norm_layer(norm_type, channels[-1]),
|
| 235 |
+
nn.SiLU(),
|
| 236 |
+
nn.Conv3d(channels[-1], latent_channels*2 if not self.latent_denoising else latent_channels, 3, padding=1)
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
if self.encode_global_skl:
|
| 240 |
+
# Initial Tokens and PE
|
| 241 |
+
self.init_tokens = nn.Parameter(torch.zeros(1, global_token_num_skl, channels[-1]))
|
| 242 |
+
pos_embedder = AbsolutePositionEmbedder(channels[-1], 1)
|
| 243 |
+
coords = torch.arange(global_token_num_skl, device=self.device).reshape(-1, 1)
|
| 244 |
+
tokens_pos_emb = pos_embedder(coords)
|
| 245 |
+
self.register_buffer('tokens_pos_emb', tokens_pos_emb)
|
| 246 |
+
# Grids PE
|
| 247 |
+
upsample_factor = 2 ** (len(channels) - 1)
|
| 248 |
+
self.base_size = 64 // upsample_factor
|
| 249 |
+
pos_embedder = AbsolutePositionEmbedder(channels[-1], 3)
|
| 250 |
+
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size] * 3], indexing='ij')
|
| 251 |
+
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
| 252 |
+
grid_pos_emb = pos_embedder(coords)
|
| 253 |
+
self.register_buffer("grid_pos_emb", grid_pos_emb)
|
| 254 |
+
# Token projection layer
|
| 255 |
+
self.token_proj = nn.Linear(channels[-1]*2, channels[-1])
|
| 256 |
+
|
| 257 |
+
# Out layers
|
| 258 |
+
self.out_layer_skl = nn.ModuleList(
|
| 259 |
+
[TransformerCrossBlock(
|
| 260 |
+
channels=channels[-1],
|
| 261 |
+
ctx_channels=channels[-1]*2,
|
| 262 |
+
out_channels=channels[-1],
|
| 263 |
+
num_heads=16,
|
| 264 |
+
attn_mode="full",
|
| 265 |
+
qkv_bias=False,
|
| 266 |
+
x_is_query=False)] +
|
| 267 |
+
[TransformerBlock(
|
| 268 |
+
channels=channels[-1],
|
| 269 |
+
out_channels=channels[-1],
|
| 270 |
+
num_heads=16,
|
| 271 |
+
attn_mode="full",
|
| 272 |
+
qkv_bias=False,
|
| 273 |
+
) for _ in range(4)] +
|
| 274 |
+
[FeedForwardNet(
|
| 275 |
+
channels=channels[-1],
|
| 276 |
+
out_channels=latent_channels_skl*2 if not self.latent_denoising_skl else latent_channels_skl)]
|
| 277 |
+
)
|
| 278 |
+
else:
|
| 279 |
+
self.out_layer_skl = nn.Sequential(
|
| 280 |
+
norm_layer(norm_type, channels[-1]),
|
| 281 |
+
nn.SiLU(),
|
| 282 |
+
nn.Conv3d(channels[-1], latent_channels_skl*2 if not self.latent_denoising_skl else latent_channels_skl, 3, padding=1)
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
self.initialize_weights()
|
| 286 |
+
if use_fp16:
|
| 287 |
+
self.convert_to_fp16()
|
| 288 |
+
|
| 289 |
+
if self.use_pretrain_branch and self.freeze_pretrain_branch:
|
| 290 |
+
# Freeze: self.input_layer, self.blocks, self.middle_block, self.out_layer
|
| 291 |
+
for module in modules_to_freeze:
|
| 292 |
+
if hasattr(self, module):
|
| 293 |
+
mod = getattr(self, module)
|
| 294 |
+
if isinstance(mod, nn.ModuleList):
|
| 295 |
+
for m in mod:
|
| 296 |
+
for param in m.parameters():
|
| 297 |
+
param.requires_grad = False
|
| 298 |
+
else:
|
| 299 |
+
for param in mod.parameters():
|
| 300 |
+
param.requires_grad = False
|
| 301 |
+
|
| 302 |
+
@property
|
| 303 |
+
def device(self) -> torch.device:
|
| 304 |
+
"""
|
| 305 |
+
Return the device of the model.
|
| 306 |
+
"""
|
| 307 |
+
return next(self.parameters()).device
|
| 308 |
+
|
| 309 |
+
def convert_to_fp16(self) -> None:
|
| 310 |
+
"""
|
| 311 |
+
Convert the torso of the model to float16.
|
| 312 |
+
"""
|
| 313 |
+
self.use_fp16 = True
|
| 314 |
+
self.dtype = torch.float16
|
| 315 |
+
self.blocks.apply(convert_module_to_f16)
|
| 316 |
+
self.middle_block.apply(convert_module_to_f16)
|
| 317 |
+
self.blocks_skl.apply(convert_module_to_f16)
|
| 318 |
+
self.middle_block_skl.apply(convert_module_to_f16)
|
| 319 |
+
if self.encode_global_skl:
|
| 320 |
+
self.token_proj.apply(convert_module_to_f16)
|
| 321 |
+
self.out_layer_skl.apply(convert_module_to_f16)
|
| 322 |
+
if self.encode_global:
|
| 323 |
+
self.token_proj_ss.apply(convert_module_to_f16)
|
| 324 |
+
self.out_layer.apply(convert_module_to_f16)
|
| 325 |
+
|
| 326 |
+
def convert_to_fp32(self) -> None:
|
| 327 |
+
"""
|
| 328 |
+
Convert the torso of the model to float32.
|
| 329 |
+
"""
|
| 330 |
+
self.use_fp16 = False
|
| 331 |
+
self.dtype = torch.float32
|
| 332 |
+
self.blocks.apply(convert_module_to_f32)
|
| 333 |
+
self.middle_block.apply(convert_module_to_f32)
|
| 334 |
+
self.blocks_skl.apply(convert_module_to_f32)
|
| 335 |
+
self.middle_block_skl.apply(convert_module_to_f32)
|
| 336 |
+
if self.encode_global_skl:
|
| 337 |
+
self.token_proj.apply(convert_module_to_f32)
|
| 338 |
+
self.out_layer_skl.apply(convert_module_to_f32)
|
| 339 |
+
if self.encode_global:
|
| 340 |
+
self.token_proj_ss.apply(convert_module_to_f32)
|
| 341 |
+
self.out_layer.apply(convert_module_to_f32)
|
| 342 |
+
|
| 343 |
+
def initialize_weights(self) -> None:
|
| 344 |
+
# Initialize transformer layers:
|
| 345 |
+
def _basic_init(module):
|
| 346 |
+
if isinstance(module, nn.Linear):
|
| 347 |
+
torch.nn.init.kaiming_uniform_(module.weight, nonlinearity='linear')
|
| 348 |
+
if module.bias is not None:
|
| 349 |
+
nn.init.constant_(module.bias, 0)
|
| 350 |
+
self.apply(_basic_init)
|
| 351 |
+
|
| 352 |
+
def forward(self, x: torch.Tensor, x_skl: torch.Tensor = None, sample_posterior: bool = False, return_raw: bool = False) -> torch.Tensor:
|
| 353 |
+
h = self.input_layer(x)
|
| 354 |
+
h = h.type(self.dtype)
|
| 355 |
+
h_skl = self.input_layer_skl(x_skl)
|
| 356 |
+
h_skl = h_skl.type(self.dtype)
|
| 357 |
+
|
| 358 |
+
for block, block_skl in zip(self.blocks, self.blocks_skl):
|
| 359 |
+
h_skl = block_skl(h_skl)
|
| 360 |
+
h = block(h)
|
| 361 |
+
h_skl = self.middle_block_skl(h_skl)
|
| 362 |
+
h = self.middle_block(h)
|
| 363 |
+
|
| 364 |
+
if self.encode_global:
|
| 365 |
+
B, C, D, H, W = h.shape
|
| 366 |
+
h = h.view(B, C, D*H*W).permute(0, 2, 1) # B, N, C
|
| 367 |
+
h = torch.cat([h, self.grid_pos_emb_ss[None].expand(B, -1, -1)], dim=-1).type(h.dtype)
|
| 368 |
+
init_tokens = torch.cat([self.init_tokens_ss, self.tokens_pos_emb_ss[None].expand_as(self.init_tokens_ss)], dim=-1).type(h.dtype)
|
| 369 |
+
tokens = self.token_proj_ss(init_tokens.expand(B, -1, -1))
|
| 370 |
+
h = self.out_layer[0](tokens, h) # B, global_token_num, C
|
| 371 |
+
for layer in self.out_layer[1:]:
|
| 372 |
+
h = layer(h)
|
| 373 |
+
h = h.type(x.dtype)
|
| 374 |
+
if self.latent_denoising:
|
| 375 |
+
if self.normalize_latent:
|
| 376 |
+
h = nn.functional.normalize(h, dim=-1) * self.normalize_scale
|
| 377 |
+
mean = h
|
| 378 |
+
logvar = torch.zeros_like(h)
|
| 379 |
+
else:
|
| 380 |
+
mean, logvar = h.chunk(2, dim=2) # B, global_token_num, C
|
| 381 |
+
if sample_posterior and not self.latent_denoising:
|
| 382 |
+
std = torch.exp(0.5 * logvar)
|
| 383 |
+
z = mean + std * torch.randn_like(std)
|
| 384 |
+
else:
|
| 385 |
+
z = mean
|
| 386 |
+
else:
|
| 387 |
+
h = h.type(x.dtype)
|
| 388 |
+
h = self.out_layer(h)
|
| 389 |
+
if self.latent_denoising:
|
| 390 |
+
if self.normalize_latent:
|
| 391 |
+
h = nn.functional.normalize(h, dim=1) * self.normalize_scale
|
| 392 |
+
mean = h
|
| 393 |
+
logvar = torch.zeros_like(h)
|
| 394 |
+
else:
|
| 395 |
+
mean, logvar = h.chunk(2, dim=1)
|
| 396 |
+
if sample_posterior and not self.latent_denoising:
|
| 397 |
+
std = torch.exp(0.5 * logvar)
|
| 398 |
+
z = mean + std * torch.randn_like(std)
|
| 399 |
+
else:
|
| 400 |
+
z = mean
|
| 401 |
+
|
| 402 |
+
if self.encode_global_skl:
|
| 403 |
+
B, C, D, H, W = h_skl.shape
|
| 404 |
+
h_skl = h_skl.view(B, C, D*H*W).permute(0, 2, 1) # B, N, C
|
| 405 |
+
h_skl = torch.cat([h_skl, self.grid_pos_emb[None].expand(B, -1, -1)], dim=-1).type(h_skl.dtype)
|
| 406 |
+
init_tokens = torch.cat([self.init_tokens, self.tokens_pos_emb[None].expand_as(self.init_tokens)], dim=-1).type(h_skl.dtype)
|
| 407 |
+
tokens = self.token_proj(init_tokens.expand(B, -1, -1))
|
| 408 |
+
h_skl = self.out_layer_skl[0](tokens, h_skl) # B, global_token_num_skl, C
|
| 409 |
+
for layer in self.out_layer_skl[1:]:
|
| 410 |
+
h_skl = layer(h_skl)
|
| 411 |
+
h_skl = h_skl.type(x_skl.dtype)
|
| 412 |
+
if self.latent_denoising_skl:
|
| 413 |
+
if self.normalize_latent_skl:
|
| 414 |
+
h_skl = nn.functional.normalize(h_skl, dim=-1) * self.normalize_scale
|
| 415 |
+
mean_skl = h_skl
|
| 416 |
+
logvar_skl = torch.zeros_like(h_skl)
|
| 417 |
+
else:
|
| 418 |
+
mean_skl, logvar_skl = h_skl.chunk(2, dim=2) # B, global_token_num_skl, C
|
| 419 |
+
if sample_posterior and not self.latent_denoising_skl:
|
| 420 |
+
std_skl = torch.exp(0.5 * logvar_skl)
|
| 421 |
+
z_skl = mean_skl + std_skl * torch.randn_like(std_skl)
|
| 422 |
+
else:
|
| 423 |
+
z_skl = mean_skl
|
| 424 |
+
else:
|
| 425 |
+
h_skl = h_skl.type(x_skl.dtype)
|
| 426 |
+
h_skl = self.out_layer_skl(h_skl)
|
| 427 |
+
if self.latent_denoising_skl:
|
| 428 |
+
if self.normalize_latent_skl:
|
| 429 |
+
h_skl = nn.functional.normalize(h_skl, dim=1) * self.normalize_scale
|
| 430 |
+
mean_skl = h_skl
|
| 431 |
+
logvar_skl = torch.zeros_like(h_skl)
|
| 432 |
+
else:
|
| 433 |
+
mean_skl, logvar_skl = h_skl.chunk(2, dim=1)
|
| 434 |
+
if sample_posterior and not self.latent_denoising_skl:
|
| 435 |
+
std_skl = torch.exp(0.5 * logvar_skl)
|
| 436 |
+
z_skl = mean_skl + std_skl * torch.randn_like(std_skl)
|
| 437 |
+
else:
|
| 438 |
+
z_skl = mean_skl
|
| 439 |
+
|
| 440 |
+
if self.latent_denoising:
|
| 441 |
+
mean = mean.detach()
|
| 442 |
+
if self.latent_denoising_skl:
|
| 443 |
+
mean_skl = mean_skl.detach()
|
| 444 |
+
|
| 445 |
+
if return_raw:
|
| 446 |
+
return z, mean, logvar, z_skl, mean_skl, logvar_skl
|
| 447 |
+
return z, z_skl
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class AniGenSparseStructureDecoder(nn.Module):
|
| 451 |
+
"""
|
| 452 |
+
Decoder for Sparse Structure (\mathcal{D}_S in the paper Sec. 3.3).
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
out_channels (int): Channels of the output.
|
| 456 |
+
latent_channels (int): Channels of the latent representation.
|
| 457 |
+
num_res_blocks (int): Number of residual blocks at each resolution.
|
| 458 |
+
channels (List[int]): Channels of the decoder blocks.
|
| 459 |
+
num_res_blocks_middle (int): Number of residual blocks in the middle.
|
| 460 |
+
norm_type (Literal["group", "layer"]): Type of normalization layer.
|
| 461 |
+
use_fp16 (bool): Whether to use FP16.
|
| 462 |
+
"""
|
| 463 |
+
def __init__(
|
| 464 |
+
self,
|
| 465 |
+
out_channels: int,
|
| 466 |
+
out_channels_skl: int,
|
| 467 |
+
latent_channels: int,
|
| 468 |
+
latent_channels_skl: int,
|
| 469 |
+
num_res_blocks: int,
|
| 470 |
+
channels: List[int],
|
| 471 |
+
num_res_blocks_middle: int = 2,
|
| 472 |
+
norm_type: Literal["group", "layer"] = "layer",
|
| 473 |
+
use_fp16: bool = False,
|
| 474 |
+
encode_global: bool = False,
|
| 475 |
+
global_token_num: int = 1024,
|
| 476 |
+
encode_global_skl: bool = True,
|
| 477 |
+
global_token_num_skl: int = 1024,
|
| 478 |
+
use_pretrain_branch: bool = True,
|
| 479 |
+
freeze_pretrain_branch: bool = True,
|
| 480 |
+
modules_to_freeze: Optional[List[str]] = ["input_layer", "blocks", "middle_block", "out_layer"],
|
| 481 |
+
normalize_z: bool = False,
|
| 482 |
+
normalize_z_skl: bool = True,
|
| 483 |
+
normalize_scale: float = 1.0,
|
| 484 |
+
):
|
| 485 |
+
super().__init__()
|
| 486 |
+
self.out_channels = out_channels
|
| 487 |
+
self.out_channels_skl = out_channels_skl
|
| 488 |
+
self.latent_channels = latent_channels
|
| 489 |
+
self.latent_channels_skl = latent_channels_skl
|
| 490 |
+
self.num_res_blocks = num_res_blocks
|
| 491 |
+
self.channels = channels
|
| 492 |
+
self.num_res_blocks_middle = num_res_blocks_middle
|
| 493 |
+
self.norm_type = norm_type
|
| 494 |
+
self.use_fp16 = use_fp16
|
| 495 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 496 |
+
self.encode_global = encode_global
|
| 497 |
+
self.global_token_num = global_token_num
|
| 498 |
+
self.encode_global_skl = encode_global_skl
|
| 499 |
+
self.global_token_num_skl = global_token_num_skl
|
| 500 |
+
self.use_pretrain_branch = use_pretrain_branch
|
| 501 |
+
self.freeze_pretrain_branch = freeze_pretrain_branch
|
| 502 |
+
self.normalize_z = normalize_z
|
| 503 |
+
self.normalize_z_skl = normalize_z_skl
|
| 504 |
+
self.normalize_scale = normalize_scale
|
| 505 |
+
|
| 506 |
+
if self.encode_global:
|
| 507 |
+
# Initial Grids and PE
|
| 508 |
+
upsample_factor = 2 ** (len(channels) - 1)
|
| 509 |
+
self.base_size_ss = 64 // upsample_factor
|
| 510 |
+
self.init_grids_ss = nn.Parameter(torch.zeros(1, channels[0], self.base_size_ss**3).permute(0, 2, 1).contiguous().clone()) # 1, N, C
|
| 511 |
+
pos_embedder = AbsolutePositionEmbedder(channels[0], 3)
|
| 512 |
+
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size_ss] * 3], indexing='ij')
|
| 513 |
+
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
| 514 |
+
grid_pos_emb = pos_embedder(coords)
|
| 515 |
+
self.register_buffer("grid_pos_emb_ss", grid_pos_emb)
|
| 516 |
+
# Tokens PE
|
| 517 |
+
pos_embedder = AbsolutePositionEmbedder(channels[0], 1)
|
| 518 |
+
coords = torch.arange(global_token_num, device=self.device).reshape(-1, 1)
|
| 519 |
+
tokens_pos_emb = pos_embedder(coords)
|
| 520 |
+
self.register_buffer('tokens_pos_emb_ss', tokens_pos_emb)
|
| 521 |
+
# Token projection layer
|
| 522 |
+
self.token_proj_ss = nn.Linear(channels[0]*2, channels[0])
|
| 523 |
+
|
| 524 |
+
# Input layers
|
| 525 |
+
self.input_layer = nn.ModuleList(
|
| 526 |
+
[TransformerBlock(
|
| 527 |
+
channels=channels[0] if _ != 0 else latent_channels + channels[0],
|
| 528 |
+
out_channels=channels[0],
|
| 529 |
+
num_heads=4 if _ == 0 else 16,
|
| 530 |
+
attn_mode="full",
|
| 531 |
+
qkv_bias=False,
|
| 532 |
+
) for _ in range(4)] +
|
| 533 |
+
[TransformerCrossBlock(
|
| 534 |
+
channels=channels[0],
|
| 535 |
+
ctx_channels=channels[0],
|
| 536 |
+
out_channels=channels[0],
|
| 537 |
+
num_heads=16,
|
| 538 |
+
attn_mode="full",
|
| 539 |
+
qkv_bias=False,
|
| 540 |
+
x_is_query=False)]
|
| 541 |
+
)
|
| 542 |
+
else:
|
| 543 |
+
self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1)
|
| 544 |
+
|
| 545 |
+
if self.encode_global_skl:
|
| 546 |
+
# Initial Grids and PE
|
| 547 |
+
upsample_factor = 2 ** (len(channels) - 1)
|
| 548 |
+
self.base_size = 64 // upsample_factor
|
| 549 |
+
self.init_grids = nn.Parameter(torch.zeros(1, channels[0], self.base_size**3).permute(0, 2, 1).contiguous().clone()) # 1, N, C
|
| 550 |
+
pos_embedder = AbsolutePositionEmbedder(channels[0], 3)
|
| 551 |
+
coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [self.base_size] * 3], indexing='ij')
|
| 552 |
+
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
| 553 |
+
grid_pos_emb = pos_embedder(coords)
|
| 554 |
+
self.register_buffer("grid_pos_emb", grid_pos_emb)
|
| 555 |
+
# Tokens PE
|
| 556 |
+
pos_embedder = AbsolutePositionEmbedder(channels[0], 1)
|
| 557 |
+
coords = torch.arange(global_token_num_skl, device=self.device).reshape(-1, 1)
|
| 558 |
+
tokens_pos_emb = pos_embedder(coords)
|
| 559 |
+
self.register_buffer('tokens_pos_emb', tokens_pos_emb)
|
| 560 |
+
# Token projection layer
|
| 561 |
+
self.token_proj = nn.Linear(channels[0]*2, channels[0])
|
| 562 |
+
|
| 563 |
+
# Input layers
|
| 564 |
+
self.input_layer_skl = nn.ModuleList(
|
| 565 |
+
[TransformerBlock(
|
| 566 |
+
channels=channels[0] if _ != 0 else latent_channels_skl + channels[0],
|
| 567 |
+
out_channels=channels[0],
|
| 568 |
+
num_heads=4 if _ == 0 else 16,
|
| 569 |
+
attn_mode="full",
|
| 570 |
+
qkv_bias=False,
|
| 571 |
+
) for _ in range(4)] +
|
| 572 |
+
[TransformerCrossBlock(
|
| 573 |
+
channels=channels[0],
|
| 574 |
+
ctx_channels=channels[0],
|
| 575 |
+
out_channels=channels[0],
|
| 576 |
+
num_heads=16,
|
| 577 |
+
attn_mode="full",
|
| 578 |
+
qkv_bias=False,
|
| 579 |
+
x_is_query=False)]
|
| 580 |
+
)
|
| 581 |
+
else:
|
| 582 |
+
self.input_layer_skl = nn.Conv3d(latent_channels_skl, channels[0], 3, padding=1)
|
| 583 |
+
|
| 584 |
+
self.middle_block = nn.Sequential(*[
|
| 585 |
+
ResBlock3d(channels[0], channels[0])
|
| 586 |
+
for _ in range(num_res_blocks_middle)
|
| 587 |
+
])
|
| 588 |
+
self.middle_block_skl = nn.Sequential(*[
|
| 589 |
+
ResBlock3d(channels[0] if _ == 0 else channels[0], channels[0])
|
| 590 |
+
for _ in range(num_res_blocks_middle)
|
| 591 |
+
])
|
| 592 |
+
|
| 593 |
+
self.blocks = nn.ModuleList([])
|
| 594 |
+
self.blocks_skl = nn.ModuleList([])
|
| 595 |
+
for i, ch in enumerate(channels):
|
| 596 |
+
self.blocks.extend([
|
| 597 |
+
ResBlock3d(ch, ch)
|
| 598 |
+
for _ in range(num_res_blocks)
|
| 599 |
+
])
|
| 600 |
+
if i < len(channels) - 1:
|
| 601 |
+
self.blocks.append(
|
| 602 |
+
UpsampleBlock3d(ch, channels[i+1])
|
| 603 |
+
)
|
| 604 |
+
self.blocks_skl.extend([
|
| 605 |
+
ResBlock3d(ch, ch)
|
| 606 |
+
for _ in range(num_res_blocks)
|
| 607 |
+
])
|
| 608 |
+
if i < len(channels) - 1:
|
| 609 |
+
self.blocks_skl.append(
|
| 610 |
+
UpsampleBlock3d(ch, channels[i+1])
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
self.out_layer = nn.Sequential(
|
| 614 |
+
norm_layer(norm_type, channels[-1]),
|
| 615 |
+
nn.SiLU(),
|
| 616 |
+
nn.Conv3d(channels[-1], self.out_channels, 3, padding=1)
|
| 617 |
+
)
|
| 618 |
+
self.out_layer_skl = nn.Sequential(
|
| 619 |
+
norm_layer(norm_type, channels[-1]),
|
| 620 |
+
nn.SiLU(),
|
| 621 |
+
nn.Conv3d(channels[-1], self.out_channels_skl, 3, padding=1)
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
self.initialize_weights()
|
| 625 |
+
if use_fp16:
|
| 626 |
+
self.convert_to_fp16()
|
| 627 |
+
|
| 628 |
+
if self.use_pretrain_branch and self.freeze_pretrain_branch:
|
| 629 |
+
# Freeze: self.input_layer, self.blocks, self.middle_block, self.out_layer
|
| 630 |
+
for module in modules_to_freeze:
|
| 631 |
+
if hasattr(self, module):
|
| 632 |
+
mod = getattr(self, module)
|
| 633 |
+
if isinstance(mod, nn.ModuleList):
|
| 634 |
+
for m in mod:
|
| 635 |
+
for param in m.parameters():
|
| 636 |
+
param.requires_grad = False
|
| 637 |
+
else:
|
| 638 |
+
for param in mod.parameters():
|
| 639 |
+
param.requires_grad = False
|
| 640 |
+
|
| 641 |
+
@property
|
| 642 |
+
def device(self) -> torch.device:
|
| 643 |
+
"""
|
| 644 |
+
Return the device of the model.
|
| 645 |
+
"""
|
| 646 |
+
return next(self.parameters()).device
|
| 647 |
+
|
| 648 |
+
def convert_to_fp16(self) -> None:
|
| 649 |
+
"""
|
| 650 |
+
Convert the torso of the model to float16.
|
| 651 |
+
"""
|
| 652 |
+
self.use_fp16 = True
|
| 653 |
+
self.dtype = torch.float16
|
| 654 |
+
self.blocks.apply(convert_module_to_f16)
|
| 655 |
+
self.middle_block.apply(convert_module_to_f16)
|
| 656 |
+
self.blocks_skl.apply(convert_module_to_f16)
|
| 657 |
+
self.middle_block_skl.apply(convert_module_to_f16)
|
| 658 |
+
if self.encode_global_skl:
|
| 659 |
+
self.token_proj.apply(convert_module_to_f16)
|
| 660 |
+
self.input_layer_skl.apply(convert_module_to_f16)
|
| 661 |
+
if self.encode_global:
|
| 662 |
+
self.token_proj_ss.apply(convert_module_to_f16)
|
| 663 |
+
self.input_layer.apply(convert_module_to_f16)
|
| 664 |
+
|
| 665 |
+
def convert_to_fp32(self) -> None:
|
| 666 |
+
"""
|
| 667 |
+
Convert the torso of the model to float32.
|
| 668 |
+
"""
|
| 669 |
+
self.use_fp16 = False
|
| 670 |
+
self.dtype = torch.float32
|
| 671 |
+
self.blocks.apply(convert_module_to_f32)
|
| 672 |
+
self.middle_block.apply(convert_module_to_f32)
|
| 673 |
+
self.blocks_skl.apply(convert_module_to_f32)
|
| 674 |
+
self.middle_block_skl.apply(convert_module_to_f32)
|
| 675 |
+
if self.encode_global_skl:
|
| 676 |
+
self.token_proj.apply(convert_module_to_f32)
|
| 677 |
+
self.input_layer_skl.apply(convert_module_to_f32)
|
| 678 |
+
if self.encode_global:
|
| 679 |
+
self.token_proj_ss.apply(convert_module_to_f32)
|
| 680 |
+
self.input_layer.apply(convert_module_to_f32)
|
| 681 |
+
|
| 682 |
+
def initialize_weights(self) -> None:
|
| 683 |
+
# Initialize transformer layers:
|
| 684 |
+
def _basic_init(module):
|
| 685 |
+
if isinstance(module, nn.Linear):
|
| 686 |
+
torch.nn.init.kaiming_uniform_(module.weight, nonlinearity='linear')
|
| 687 |
+
if module.bias is not None:
|
| 688 |
+
nn.init.constant_(module.bias, 0)
|
| 689 |
+
self.apply(_basic_init)
|
| 690 |
+
|
| 691 |
+
def forward(self, x: torch.Tensor, x_skl: torch.Tensor) -> torch.Tensor:
|
| 692 |
+
h = F.normalize(x, dim=1) * self.normalize_scale if self.normalize_z else x
|
| 693 |
+
h_skl = F.normalize(x_skl, dim=1) * self.normalize_scale if self.normalize_z_skl else x_skl
|
| 694 |
+
if self.encode_global:
|
| 695 |
+
B, _, _ = h.shape
|
| 696 |
+
h = torch.cat([h, self.tokens_pos_emb_ss[None].expand(B, -1, -1)], dim=-1).type(self.dtype)
|
| 697 |
+
h = h.type(self.dtype)
|
| 698 |
+
for layer in self.input_layer[:-1]:
|
| 699 |
+
h = layer(h)
|
| 700 |
+
init_grids = torch.cat([self.init_grids_ss, self.grid_pos_emb_ss[None].expand_as(self.init_grids_ss)], dim=-1).type(self.dtype)
|
| 701 |
+
grids = self.token_proj_ss(init_grids.expand(B, -1, -1))
|
| 702 |
+
h = self.input_layer[-1](grids, h) # B, N, C
|
| 703 |
+
h = h.permute(0, 2, 1).view(B, -1, self.base_size, self.base_size, self.base_size)
|
| 704 |
+
else:
|
| 705 |
+
h = self.input_layer(h)
|
| 706 |
+
h = h.type(self.dtype)
|
| 707 |
+
if self.encode_global_skl:
|
| 708 |
+
B, _, _ = h_skl.shape
|
| 709 |
+
h_skl = torch.cat([h_skl, self.tokens_pos_emb[None].expand(B, -1, -1)], dim=-1).type(self.dtype)
|
| 710 |
+
h_skl = h_skl.type(self.dtype)
|
| 711 |
+
for layer in self.input_layer_skl[:-1]:
|
| 712 |
+
h_skl = layer(h_skl)
|
| 713 |
+
init_grids = torch.cat([self.init_grids, self.grid_pos_emb[None].expand_as(self.init_grids)], dim=-1).type(self.dtype)
|
| 714 |
+
grids = self.token_proj(init_grids.expand(B, -1, -1))
|
| 715 |
+
h_skl = self.input_layer_skl[-1](grids, h_skl) # B, N, C
|
| 716 |
+
h_skl = h_skl.permute(0, 2, 1).view(B, -1, self.base_size, self.base_size, self.base_size)
|
| 717 |
+
else:
|
| 718 |
+
h_skl = self.input_layer_skl(h_skl)
|
| 719 |
+
h_skl = h_skl.type(self.dtype)
|
| 720 |
+
h_skl = self.middle_block_skl(h_skl)
|
| 721 |
+
h = self.middle_block(h)
|
| 722 |
+
for block, block_skl in zip(self.blocks, self.blocks_skl):
|
| 723 |
+
h_skl = block_skl(h_skl)
|
| 724 |
+
h = block(h)
|
| 725 |
+
h = h.type(x.dtype)
|
| 726 |
+
h = self.out_layer(h)
|
| 727 |
+
h_skl = h_skl.type(x.dtype)
|
| 728 |
+
h_skl = self.out_layer_skl(h_skl)
|
| 729 |
+
return h, h_skl
|
anigen/models/anigen_structured_latent_flow.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from anigen.modules.transformer import blocks
|
| 9 |
+
from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
|
| 10 |
+
from ..modules.transformer import AbsolutePositionEmbedder
|
| 11 |
+
from ..modules.norm import LayerNorm32
|
| 12 |
+
from ..modules import sparse as sp
|
| 13 |
+
from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
|
| 14 |
+
from .sparse_elastic_mixin import SparseTransformerElasticMixin
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TimestepEmbedder(nn.Module):
|
| 18 |
+
"""
|
| 19 |
+
Embeds scalar timesteps into vector representations.
|
| 20 |
+
"""
|
| 21 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.mlp = nn.Sequential(
|
| 24 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 25 |
+
nn.SiLU(),
|
| 26 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 27 |
+
)
|
| 28 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 29 |
+
|
| 30 |
+
@staticmethod
|
| 31 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 32 |
+
"""
|
| 33 |
+
Create sinusoidal timestep embeddings.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
t: a 1-D Tensor of N indices, one per batch element.
|
| 37 |
+
These may be fractional.
|
| 38 |
+
dim: the dimension of the output.
|
| 39 |
+
max_period: controls the minimum frequency of the embeddings.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
an (N, D) Tensor of positional embeddings.
|
| 43 |
+
"""
|
| 44 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 45 |
+
half = dim // 2
|
| 46 |
+
freqs = torch.exp(
|
| 47 |
+
-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 48 |
+
).to(device=t.device)
|
| 49 |
+
args = t[:, None].float() * freqs[None]
|
| 50 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 51 |
+
if dim % 2:
|
| 52 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 53 |
+
return embedding
|
| 54 |
+
|
| 55 |
+
def forward(self, t):
|
| 56 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 57 |
+
t_emb = self.mlp(t_freq)
|
| 58 |
+
return t_emb
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class SparseResBlock3d(nn.Module):
|
| 62 |
+
def __init__(
|
| 63 |
+
self,
|
| 64 |
+
channels: int,
|
| 65 |
+
emb_channels: int,
|
| 66 |
+
out_channels: Optional[int] = None,
|
| 67 |
+
downsample: bool = False,
|
| 68 |
+
upsample: bool = False,
|
| 69 |
+
):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.channels = channels
|
| 72 |
+
self.emb_channels = emb_channels
|
| 73 |
+
self.out_channels = out_channels or channels
|
| 74 |
+
self.downsample = downsample
|
| 75 |
+
self.upsample = upsample
|
| 76 |
+
|
| 77 |
+
assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
|
| 78 |
+
|
| 79 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
| 80 |
+
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
|
| 81 |
+
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
|
| 82 |
+
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
|
| 83 |
+
self.emb_layers = nn.Sequential(
|
| 84 |
+
nn.SiLU(),
|
| 85 |
+
nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
|
| 86 |
+
)
|
| 87 |
+
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
|
| 88 |
+
self.updown = None
|
| 89 |
+
if self.downsample:
|
| 90 |
+
self.updown = sp.SparseDownsample(2)
|
| 91 |
+
elif self.upsample:
|
| 92 |
+
self.updown = sp.SparseUpsample(2)
|
| 93 |
+
|
| 94 |
+
def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
| 95 |
+
if self.updown is not None:
|
| 96 |
+
x = self.updown(x)
|
| 97 |
+
return x
|
| 98 |
+
|
| 99 |
+
def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
|
| 100 |
+
emb_out = self.emb_layers(emb).type(x.dtype)
|
| 101 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
| 102 |
+
|
| 103 |
+
x = self._updown(x)
|
| 104 |
+
h = x.replace(self.norm1(x.feats))
|
| 105 |
+
h = h.replace(F.silu(h.feats))
|
| 106 |
+
h = self.conv1(h)
|
| 107 |
+
h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
|
| 108 |
+
h = h.replace(F.silu(h.feats))
|
| 109 |
+
h = self.conv2(h)
|
| 110 |
+
h = h + self.skip_connection(x)
|
| 111 |
+
|
| 112 |
+
return h
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class AniGenSLatFlowModel(nn.Module):
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
resolution: int,
|
| 119 |
+
in_channels: int,
|
| 120 |
+
in_channels_vert_skin: int,
|
| 121 |
+
in_channels_skl: int,
|
| 122 |
+
model_channels: int,
|
| 123 |
+
model_channels_vert_skin: int,
|
| 124 |
+
model_channels_skl: int,
|
| 125 |
+
cond_channels: int,
|
| 126 |
+
out_channels: int,
|
| 127 |
+
out_channels_vert_skin: int,
|
| 128 |
+
out_channels_skl: int,
|
| 129 |
+
num_blocks: int,
|
| 130 |
+
num_heads: Optional[int] = None,
|
| 131 |
+
num_head_channels: Optional[int] = 64,
|
| 132 |
+
num_heads_vert_skin: Optional[int] = None,
|
| 133 |
+
num_head_channels_vert_skin: Optional[int] = 64,
|
| 134 |
+
num_heads_skl: Optional[int] = None,
|
| 135 |
+
num_head_channels_skl: Optional[int] = 64,
|
| 136 |
+
mlp_ratio: float = 4,
|
| 137 |
+
patch_size: int = 2,
|
| 138 |
+
num_io_res_blocks: int = 2,
|
| 139 |
+
num_io_res_blocks_vert_skin: int = 2,
|
| 140 |
+
num_io_res_blocks_skl: int = 2,
|
| 141 |
+
io_block_channels: List[int] = None,
|
| 142 |
+
io_block_channels_vert_skin: List[int] = None,
|
| 143 |
+
io_block_channels_skl: List[int] = None,
|
| 144 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 145 |
+
use_fp16: bool = False,
|
| 146 |
+
use_checkpoint: bool = False,
|
| 147 |
+
use_skip_connection: bool = True,
|
| 148 |
+
share_mod: bool = False,
|
| 149 |
+
qk_rms_norm: bool = False,
|
| 150 |
+
qk_rms_norm_cross: bool = False,
|
| 151 |
+
use_pretrain_branch: bool = True,
|
| 152 |
+
freeze_pretrain_branch: bool = True,
|
| 153 |
+
modules_to_freeze: Optional[List[str]] = ['blocks', 'input_blocks','input_layer', 'out_blocks', 'out_layer', 't_embedder'],
|
| 154 |
+
predict_x0: bool = False,
|
| 155 |
+
t_eps: float = 5e-2,
|
| 156 |
+
t_scale: float = 1e3,
|
| 157 |
+
use_joint_num_cond: bool = False,
|
| 158 |
+
joint_num_max: int = 60,
|
| 159 |
+
joint_num_fourier_bands: int = 6,
|
| 160 |
+
):
|
| 161 |
+
super().__init__()
|
| 162 |
+
|
| 163 |
+
self.pretrain_class_name = ["AniGenSlatFlowImage"]
|
| 164 |
+
|
| 165 |
+
self.resolution = resolution
|
| 166 |
+
self.in_channels = in_channels
|
| 167 |
+
self.in_channels_vert_skin = in_channels_vert_skin
|
| 168 |
+
self.in_channels_skl = in_channels_skl
|
| 169 |
+
self.model_channels = model_channels
|
| 170 |
+
self.model_channels_vert_skin = model_channels_vert_skin
|
| 171 |
+
self.model_channels_skl = model_channels_skl
|
| 172 |
+
self.cond_channels = cond_channels
|
| 173 |
+
self.out_channels = out_channels
|
| 174 |
+
self.out_channels_vert_skin = out_channels_vert_skin
|
| 175 |
+
self.out_channels_skl = out_channels_skl
|
| 176 |
+
self.num_blocks = num_blocks
|
| 177 |
+
self.num_heads = num_heads or model_channels // num_head_channels
|
| 178 |
+
self.num_heads_vert_skin = num_heads_vert_skin or model_channels_vert_skin // num_head_channels_vert_skin
|
| 179 |
+
self.num_heads_skl = num_heads_skl or model_channels_skl // num_head_channels_skl
|
| 180 |
+
self.mlp_ratio = mlp_ratio
|
| 181 |
+
self.patch_size = patch_size
|
| 182 |
+
self.num_io_res_blocks = num_io_res_blocks
|
| 183 |
+
self.num_io_res_blocks_vert_skin = num_io_res_blocks_vert_skin
|
| 184 |
+
self.num_io_res_blocks_skl = num_io_res_blocks_skl
|
| 185 |
+
self.io_block_channels = io_block_channels
|
| 186 |
+
self.io_block_channels_vert_skin = io_block_channels_vert_skin
|
| 187 |
+
self.io_block_channels_skl = io_block_channels_skl
|
| 188 |
+
self.pe_mode = pe_mode
|
| 189 |
+
self.use_fp16 = use_fp16
|
| 190 |
+
self.use_checkpoint = use_checkpoint
|
| 191 |
+
self.use_skip_connection = use_skip_connection
|
| 192 |
+
self.share_mod = share_mod
|
| 193 |
+
self.qk_rms_norm = qk_rms_norm
|
| 194 |
+
self.qk_rms_norm_cross = qk_rms_norm_cross
|
| 195 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 196 |
+
self.predict_x0 = predict_x0
|
| 197 |
+
self.t_eps = t_eps
|
| 198 |
+
self.t_scale = t_scale
|
| 199 |
+
self.use_joint_num_cond = use_joint_num_cond
|
| 200 |
+
self.joint_num_max = joint_num_max
|
| 201 |
+
self.joint_num_fourier_bands = joint_num_fourier_bands
|
| 202 |
+
|
| 203 |
+
if self.io_block_channels is not None:
|
| 204 |
+
assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
|
| 205 |
+
assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
|
| 206 |
+
|
| 207 |
+
self.t_embedder = TimestepEmbedder(model_channels)
|
| 208 |
+
self.t_embedder_vert_skin = TimestepEmbedder(model_channels_vert_skin)
|
| 209 |
+
self.t_embedder_skl = TimestepEmbedder(model_channels_skl)
|
| 210 |
+
|
| 211 |
+
if self.use_joint_num_cond:
|
| 212 |
+
# Joint-number conditioning (applied to skin + skeleton branches).
|
| 213 |
+
# If joints_num is missing/<=0, use learnable unconditional embeddings.
|
| 214 |
+
self.joint_num_embedder_vert_skin = nn.Sequential(
|
| 215 |
+
nn.Linear(2 * joint_num_fourier_bands, model_channels_vert_skin, bias=True),
|
| 216 |
+
nn.SiLU(),
|
| 217 |
+
nn.Linear(model_channels_vert_skin, model_channels_vert_skin, bias=True),
|
| 218 |
+
)
|
| 219 |
+
self.joint_num_embedder_skl = nn.Sequential(
|
| 220 |
+
nn.Linear(2 * joint_num_fourier_bands, model_channels_skl, bias=True),
|
| 221 |
+
nn.SiLU(),
|
| 222 |
+
nn.Linear(model_channels_skl, model_channels_skl, bias=True),
|
| 223 |
+
)
|
| 224 |
+
self.joint_num_uncond_vert_skin = nn.Parameter(torch.zeros(model_channels_vert_skin))
|
| 225 |
+
self.joint_num_uncond_skl = nn.Parameter(torch.zeros(model_channels_skl))
|
| 226 |
+
if share_mod:
|
| 227 |
+
self.adaLN_modulation = nn.Sequential(
|
| 228 |
+
nn.SiLU(),
|
| 229 |
+
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
| 230 |
+
)
|
| 231 |
+
self.adaLN_modulation_vert_skin = nn.Sequential(
|
| 232 |
+
nn.SiLU(),
|
| 233 |
+
nn.Linear(model_channels_vert_skin, 6 * model_channels_vert_skin, bias=True)
|
| 234 |
+
)
|
| 235 |
+
self.adaLN_modulation_skl = nn.Sequential(
|
| 236 |
+
nn.SiLU(),
|
| 237 |
+
nn.Linear(model_channels_skl, 6 * model_channels_skl, bias=True)
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
if pe_mode == "ape":
|
| 241 |
+
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
| 242 |
+
self.pos_embedder_vert_skin = AbsolutePositionEmbedder(model_channels_vert_skin)
|
| 243 |
+
self.pos_embedder_skl = AbsolutePositionEmbedder(model_channels_skl)
|
| 244 |
+
|
| 245 |
+
# Causuality in conditioning:
|
| 246 |
+
# Geometry <- Conditioned Image (Cross Attention)
|
| 247 |
+
# Skinning <- Geometry (Adapter Layer) + Skeleton (Cross Attention)
|
| 248 |
+
# Skeleton <- Skinning (Cross Attention)
|
| 249 |
+
causial_cond_channels_dict = {'': cond_channels, '_vert_skin': self.model_channels_skl, '_skl': self.model_channels_vert_skin}
|
| 250 |
+
|
| 251 |
+
for postfix in ['', '_vert_skin', '_skl']:
|
| 252 |
+
# Input blocks
|
| 253 |
+
setattr(self, f'input_layer{postfix}', sp.SparseLinear(
|
| 254 |
+
getattr(self, f'in_channels{postfix}'),
|
| 255 |
+
getattr(self, f'model_channels{postfix}') if getattr(self, f'io_block_channels{postfix}') is None else getattr(self, f'io_block_channels{postfix}')[0]
|
| 256 |
+
))
|
| 257 |
+
|
| 258 |
+
setattr(self, f'input_blocks{postfix}', nn.ModuleList([]))
|
| 259 |
+
io_block_channels = getattr(self, f'io_block_channels{postfix}')
|
| 260 |
+
model_channels = getattr(self, f'model_channels{postfix}')
|
| 261 |
+
num_io_res_blocks = getattr(self, f'num_io_res_blocks{postfix}')
|
| 262 |
+
if io_block_channels is not None:
|
| 263 |
+
for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
|
| 264 |
+
getattr(self, f'input_blocks{postfix}').extend([
|
| 265 |
+
SparseResBlock3d(
|
| 266 |
+
chs,
|
| 267 |
+
model_channels,
|
| 268 |
+
out_channels=chs,
|
| 269 |
+
)
|
| 270 |
+
for _ in range(num_io_res_blocks-1)
|
| 271 |
+
])
|
| 272 |
+
getattr(self, f'input_blocks{postfix}').append(
|
| 273 |
+
SparseResBlock3d(
|
| 274 |
+
chs,
|
| 275 |
+
model_channels,
|
| 276 |
+
out_channels=next_chs,
|
| 277 |
+
downsample=True,
|
| 278 |
+
)
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# Transformer blocks
|
| 282 |
+
cond_channels_block = causial_cond_channels_dict[postfix]
|
| 283 |
+
setattr(self, f'blocks{postfix}', nn.ModuleList([
|
| 284 |
+
ModulatedSparseTransformerCrossBlock(
|
| 285 |
+
getattr(self, f'model_channels{postfix}'),
|
| 286 |
+
cond_channels_block,
|
| 287 |
+
num_heads=getattr(self, f'num_heads{postfix}'),
|
| 288 |
+
mlp_ratio=self.mlp_ratio,
|
| 289 |
+
attn_mode='full',
|
| 290 |
+
use_checkpoint=self.use_checkpoint,
|
| 291 |
+
use_rope=(pe_mode == "rope"),
|
| 292 |
+
share_mod=self.share_mod,
|
| 293 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 294 |
+
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
| 295 |
+
norm_for_context=True,
|
| 296 |
+
)
|
| 297 |
+
for _ in range(num_blocks)
|
| 298 |
+
]))
|
| 299 |
+
|
| 300 |
+
# Output blocks
|
| 301 |
+
setattr(self, f'out_blocks{postfix}', nn.ModuleList([]))
|
| 302 |
+
if io_block_channels is not None:
|
| 303 |
+
for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
|
| 304 |
+
getattr(self, f'out_blocks{postfix}').append(
|
| 305 |
+
SparseResBlock3d(
|
| 306 |
+
prev_chs * 2 if self.use_skip_connection else prev_chs,
|
| 307 |
+
model_channels,
|
| 308 |
+
out_channels=chs,
|
| 309 |
+
upsample=True,
|
| 310 |
+
)
|
| 311 |
+
)
|
| 312 |
+
getattr(self, f'out_blocks{postfix}').extend([
|
| 313 |
+
SparseResBlock3d(
|
| 314 |
+
chs * 2 if self.use_skip_connection else chs,
|
| 315 |
+
model_channels,
|
| 316 |
+
out_channels=chs,
|
| 317 |
+
)
|
| 318 |
+
for _ in range(num_io_res_blocks-1)
|
| 319 |
+
])
|
| 320 |
+
setattr(self, f'out_layer{postfix}', sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], getattr(self, f'out_channels{postfix}')))
|
| 321 |
+
|
| 322 |
+
self.adapter_geo_to_skin = nn.ModuleList([
|
| 323 |
+
sp.SparseLinear(self.model_channels, self.model_channels_vert_skin) for _ in range(num_blocks)
|
| 324 |
+
])
|
| 325 |
+
|
| 326 |
+
self.initialize_weights()
|
| 327 |
+
if use_fp16:
|
| 328 |
+
self.convert_to_fp16()
|
| 329 |
+
|
| 330 |
+
self.use_pretrain_branch = use_pretrain_branch
|
| 331 |
+
self.freeze_pretrain_branch = freeze_pretrain_branch
|
| 332 |
+
# self.is_geometry_branch_frozen = self.use_pretrain_branch and self.freeze_pretrain_branch and all([module in modules_to_freeze for module in ['blocks', 'input_blocks','input_layer', 'out_blocks', 'out_layer', 't_embedder']])
|
| 333 |
+
|
| 334 |
+
if self.use_pretrain_branch and self.freeze_pretrain_branch:
|
| 335 |
+
for module in modules_to_freeze:
|
| 336 |
+
if hasattr(self, module):
|
| 337 |
+
mod = getattr(self, module)
|
| 338 |
+
if isinstance(mod, nn.ModuleList):
|
| 339 |
+
for m in mod:
|
| 340 |
+
for param in m.parameters():
|
| 341 |
+
param.requires_grad = False
|
| 342 |
+
else:
|
| 343 |
+
for param in mod.parameters():
|
| 344 |
+
param.requires_grad = False
|
| 345 |
+
|
| 346 |
+
@property
|
| 347 |
+
def device(self) -> torch.device:
|
| 348 |
+
"""
|
| 349 |
+
Return the device of the model.
|
| 350 |
+
"""
|
| 351 |
+
return next(self.parameters()).device
|
| 352 |
+
|
| 353 |
+
def convert_to_fp16(self) -> None:
|
| 354 |
+
"""
|
| 355 |
+
Convert the torso of the model to float16.
|
| 356 |
+
"""
|
| 357 |
+
for postfix in ['', '_vert_skin', '_skl']:
|
| 358 |
+
getattr(self, f'input_blocks{postfix}').apply(convert_module_to_f16)
|
| 359 |
+
getattr(self, f'blocks{postfix}').apply(convert_module_to_f16)
|
| 360 |
+
getattr(self, f'out_blocks{postfix}').apply(convert_module_to_f16)
|
| 361 |
+
self.adapter_geo_to_skin.apply(convert_module_to_f16)
|
| 362 |
+
if self.use_joint_num_cond:
|
| 363 |
+
self.joint_num_embedder_vert_skin.apply(convert_module_to_f16)
|
| 364 |
+
self.joint_num_embedder_skl.apply(convert_module_to_f16)
|
| 365 |
+
self.joint_num_uncond_vert_skin.data = self.joint_num_uncond_vert_skin.data.half()
|
| 366 |
+
self.joint_num_uncond_skl.data = self.joint_num_uncond_skl.data.half()
|
| 367 |
+
|
| 368 |
+
def convert_to_fp32(self) -> None:
|
| 369 |
+
"""
|
| 370 |
+
Convert the torso of the model to float32.
|
| 371 |
+
"""
|
| 372 |
+
for postfix in ['', '_vert_skin', '_skl']:
|
| 373 |
+
getattr(self, f'input_blocks{postfix}').apply(convert_module_to_f32)
|
| 374 |
+
getattr(self, f'blocks{postfix}').apply(convert_module_to_f32)
|
| 375 |
+
getattr(self, f'out_blocks{postfix}').apply(convert_module_to_f32)
|
| 376 |
+
self.adapter_geo_to_skin.apply(convert_module_to_f32)
|
| 377 |
+
if self.use_joint_num_cond:
|
| 378 |
+
self.joint_num_embedder_vert_skin.apply(convert_module_to_f32)
|
| 379 |
+
self.joint_num_embedder_skl.apply(convert_module_to_f32)
|
| 380 |
+
self.joint_num_uncond_vert_skin.data = self.joint_num_uncond_vert_skin.data.float()
|
| 381 |
+
self.joint_num_uncond_skl.data = self.joint_num_uncond_skl.data.float()
|
| 382 |
+
|
| 383 |
+
def initialize_weights(self) -> None:
|
| 384 |
+
# Initialize transformer layers:
|
| 385 |
+
def _basic_init(module):
|
| 386 |
+
if isinstance(module, nn.Linear):
|
| 387 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 388 |
+
if module.bias is not None:
|
| 389 |
+
nn.init.constant_(module.bias, 0)
|
| 390 |
+
self.apply(_basic_init)
|
| 391 |
+
|
| 392 |
+
for postfix in ['', '_vert_skin', '_skl']:
|
| 393 |
+
nn.init.normal_(getattr(self, f't_embedder{postfix}').mlp[0].weight, std=0.02)
|
| 394 |
+
nn.init.normal_(getattr(self, f't_embedder{postfix}').mlp[2].weight, std=0.02)
|
| 395 |
+
if self.share_mod:
|
| 396 |
+
nn.init.constant_(getattr(self, f'adaLN_modulation{postfix}')[-1].weight, 0)
|
| 397 |
+
nn.init.constant_(getattr(self, f'adaLN_modulation{postfix}')[-1].bias, 0)
|
| 398 |
+
else:
|
| 399 |
+
for block in getattr(self, f'blocks{postfix}'):
|
| 400 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 401 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 402 |
+
nn.init.constant_(getattr(self, f'out_layer{postfix}').weight, 0)
|
| 403 |
+
nn.init.constant_(getattr(self, f'out_layer{postfix}').bias, 0)
|
| 404 |
+
|
| 405 |
+
for layer in self.adapter_geo_to_skin:
|
| 406 |
+
nn.init.constant_(layer.weight, 0)
|
| 407 |
+
nn.init.constant_(layer.bias, 0)
|
| 408 |
+
|
| 409 |
+
if self.use_joint_num_cond:
|
| 410 |
+
# Joint-number conditioning layers
|
| 411 |
+
for emb in [self.joint_num_embedder_vert_skin, self.joint_num_embedder_skl]:
|
| 412 |
+
for m in emb.modules():
|
| 413 |
+
if isinstance(m, nn.Linear):
|
| 414 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
| 415 |
+
if m.bias is not None:
|
| 416 |
+
nn.init.constant_(m.bias, 0)
|
| 417 |
+
|
| 418 |
+
def _fourier_encode_joint_num(self, joints_num: torch.Tensor) -> torch.Tensor:
|
| 419 |
+
"""Fourier features for joints_num in [0, joint_num_max]."""
|
| 420 |
+
# Keep dtype consistent with model (e.g., fp16) to avoid Linear dtype mismatch.
|
| 421 |
+
dtype = getattr(self, 'dtype', torch.float32)
|
| 422 |
+
x = (joints_num.to(dtype=dtype) / float(self.joint_num_max)).clamp(0.0, 1.0)
|
| 423 |
+
x = x[:, None]
|
| 424 |
+
freqs = (2.0 ** torch.arange(self.joint_num_fourier_bands, device=x.device, dtype=x.dtype)) * math.pi
|
| 425 |
+
angles = x * freqs[None, :]
|
| 426 |
+
return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
| 427 |
+
|
| 428 |
+
def _get_joint_num_emb(self, joints_num: Optional[torch.Tensor], batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 429 |
+
"""Return (emb_vert_skin, emb_skl), shape [B, C_*]."""
|
| 430 |
+
if joints_num is None:
|
| 431 |
+
joints_num = torch.zeros(batch_size, device=device)
|
| 432 |
+
elif not torch.is_tensor(joints_num):
|
| 433 |
+
joints_num = torch.tensor(joints_num, device=device)
|
| 434 |
+
joints_num = joints_num.to(device=device)
|
| 435 |
+
if joints_num.dim() == 0:
|
| 436 |
+
joints_num = joints_num[None].expand(batch_size)
|
| 437 |
+
joints_num = joints_num.reshape(batch_size)
|
| 438 |
+
|
| 439 |
+
mask_dtype = getattr(self, 'dtype', torch.float32)
|
| 440 |
+
uncond_mask = (joints_num <= 0).to(dtype=mask_dtype, device=device)[:, None]
|
| 441 |
+
joints_num = joints_num.clamp(min=0, max=self.joint_num_max)
|
| 442 |
+
|
| 443 |
+
fourier = self._fourier_encode_joint_num(joints_num)
|
| 444 |
+
emb_vs_cond = self.joint_num_embedder_vert_skin(fourier)
|
| 445 |
+
emb_skl_cond = self.joint_num_embedder_skl(fourier)
|
| 446 |
+
|
| 447 |
+
emb_vs_uncond = self.joint_num_uncond_vert_skin[None].expand(batch_size, -1)
|
| 448 |
+
emb_skl_uncond = self.joint_num_uncond_skl[None].expand(batch_size, -1)
|
| 449 |
+
|
| 450 |
+
# Blend: uncond_mask==1 -> unconditional, uncond_mask==0 -> conditional.
|
| 451 |
+
emb_vs = emb_vs_cond * (1.0 - uncond_mask) + emb_vs_uncond * uncond_mask
|
| 452 |
+
emb_skl = emb_skl_cond * (1.0 - uncond_mask) + emb_skl_uncond * uncond_mask
|
| 453 |
+
return emb_vs, emb_skl
|
| 454 |
+
|
| 455 |
+
def forward_stage(
|
| 456 |
+
self,
|
| 457 |
+
x: sp.SparseTensor,
|
| 458 |
+
t: torch.Tensor,
|
| 459 |
+
postfix,
|
| 460 |
+
stage,
|
| 461 |
+
cond_emb: Optional[torch.Tensor] = None,
|
| 462 |
+
t_emb=None,
|
| 463 |
+
skips=None,
|
| 464 |
+
original_dtype=None,
|
| 465 |
+
) -> sp.SparseTensor:
|
| 466 |
+
input_layer = getattr(self, f'input_layer{postfix}')
|
| 467 |
+
t_embedder = getattr(self, f't_embedder{postfix}')
|
| 468 |
+
input_blocks = getattr(self, f'input_blocks{postfix}')
|
| 469 |
+
pos_embedder = getattr(self, f'pos_embedder{postfix}')
|
| 470 |
+
out_blocks = getattr(self, f'out_blocks{postfix}')
|
| 471 |
+
out_layer = getattr(self, f'out_layer{postfix}')
|
| 472 |
+
adaLN_modulation = getattr(self, f'adaLN_modulation{postfix}') if self.share_mod else None
|
| 473 |
+
|
| 474 |
+
if stage == 'in':
|
| 475 |
+
h = input_layer(x).type(self.dtype)
|
| 476 |
+
t_emb = t_embedder(t)
|
| 477 |
+
if cond_emb is not None:
|
| 478 |
+
t_emb = t_emb + cond_emb
|
| 479 |
+
t_emb = t_emb.type(self.dtype)
|
| 480 |
+
t_mod = adaLN_modulation(t_emb).type(self.dtype) if self.share_mod else t_emb
|
| 481 |
+
skips = []
|
| 482 |
+
# pack with input blocks
|
| 483 |
+
for block in input_blocks:
|
| 484 |
+
h = block(h, t_emb)
|
| 485 |
+
skips.append(h.feats)
|
| 486 |
+
if self.pe_mode == "ape":
|
| 487 |
+
h = h + pos_embedder(h.coords[:, 1:]).type(self.dtype)
|
| 488 |
+
return h, t_emb, t_mod, skips
|
| 489 |
+
elif stage == 'out':
|
| 490 |
+
h = x
|
| 491 |
+
# unpack with output blocks
|
| 492 |
+
for block, skip in zip(out_blocks, reversed(skips)):
|
| 493 |
+
if self.use_skip_connection:
|
| 494 |
+
h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
|
| 495 |
+
else:
|
| 496 |
+
h = block(h, t_emb)
|
| 497 |
+
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
| 498 |
+
h = out_layer(h.type(original_dtype))
|
| 499 |
+
return h
|
| 500 |
+
else:
|
| 501 |
+
raise ValueError(f"Unknown stage: {stage}")
|
| 502 |
+
|
| 503 |
+
def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor, joints_num: Optional[torch.Tensor] = None, **kwargs) -> sp.SparseTensor:
|
| 504 |
+
cond = cond.type(self.dtype)
|
| 505 |
+
feats, feats_vert_skin = x.feats[:, :self.in_channels], x.feats[:, self.in_channels:]
|
| 506 |
+
x, x_vert_skin = x.replace(feats), x.replace(feats_vert_skin)
|
| 507 |
+
if self.predict_x0:
|
| 508 |
+
xt_feats_skin, xt_feats_skl = feats_vert_skin.clone(), x_skl.feats.clone()
|
| 509 |
+
|
| 510 |
+
joint_emb_vs, joint_emb_skl = None, None
|
| 511 |
+
if self.use_joint_num_cond:
|
| 512 |
+
# joint-number conditioning for skin + skeleton
|
| 513 |
+
joint_emb_vs, joint_emb_skl = self._get_joint_num_emb(joints_num, x.shape[0], x.device)
|
| 514 |
+
joint_emb_vs = joint_emb_vs.type(self.dtype)
|
| 515 |
+
joint_emb_skl = joint_emb_skl.type(self.dtype)
|
| 516 |
+
|
| 517 |
+
in_dicts = {'': x, '_vert_skin': x_vert_skin, '_skl': x_skl}
|
| 518 |
+
cond_emb_dicts = {'': None, '_vert_skin': joint_emb_vs, '_skl': joint_emb_skl}
|
| 519 |
+
postfix_keys = list(in_dicts.keys())
|
| 520 |
+
for postfix in postfix_keys:
|
| 521 |
+
cond_emb = cond_emb_dicts[postfix]
|
| 522 |
+
in_dicts[postfix], in_dicts[f't_emb{postfix}'], in_dicts[f't_mod{postfix}'], in_dicts[f'skips{postfix}'] = self.forward_stage(in_dicts[postfix], t, postfix, stage='in', cond_emb=cond_emb)
|
| 523 |
+
for block, block_skin, block_skl, adapter in zip(self.blocks, self.blocks_vert_skin, self.blocks_skl, self.adapter_geo_to_skin):
|
| 524 |
+
h, h_skin, h_skl = in_dicts[''], in_dicts['_vert_skin'], in_dicts['_skl']
|
| 525 |
+
f = block(h, in_dicts['t_mod'], cond)
|
| 526 |
+
f_skin = block_skin(h_skin, in_dicts['t_mod_vert_skin'], h_skl) + adapter(h)
|
| 527 |
+
f_skl = block_skl(h_skl, in_dicts['t_mod_skl'], h_skin)
|
| 528 |
+
in_dicts[''], in_dicts['_vert_skin'], in_dicts['_skl'] = f, f_skin, f_skl
|
| 529 |
+
for postfix in postfix_keys:
|
| 530 |
+
in_dicts[postfix] = self.forward_stage(
|
| 531 |
+
in_dicts[postfix],
|
| 532 |
+
t,
|
| 533 |
+
postfix,
|
| 534 |
+
stage='out',
|
| 535 |
+
t_emb=in_dicts[f't_emb{postfix}'],
|
| 536 |
+
skips=in_dicts[f'skips{postfix}'],
|
| 537 |
+
original_dtype=x.dtype,
|
| 538 |
+
)
|
| 539 |
+
if self.predict_x0:
|
| 540 |
+
t_normalized = t / self.t_scale
|
| 541 |
+
factor = (1 / t_normalized.clamp_min(self.t_eps))[:, None]
|
| 542 |
+
in_dicts['_vert_skin'] = in_dicts['_vert_skin'].replace((in_dicts['_vert_skin'].feats - xt_feats_skin) * factor[in_dicts['_vert_skin'].coords[:, 0]])
|
| 543 |
+
in_dicts['_skl'] = in_dicts['_skl'].replace((in_dicts['_skl'].feats - xt_feats_skl) * factor[in_dicts['_skl'].coords[:, 0]])
|
| 544 |
+
x_out = x.replace(torch.cat([in_dicts[''].feats, in_dicts['_vert_skin'].feats], dim=1))
|
| 545 |
+
x_skl_out = x_skl.replace(in_dicts['_skl'].feats)
|
| 546 |
+
return x_out, x_skl_out
|
| 547 |
+
|
| 548 |
+
class AniGenElasticSLatFlowModel(SparseTransformerElasticMixin, AniGenSLatFlowModel):
|
| 549 |
+
"""
|
| 550 |
+
SLat Flow Model with elastic memory management.
|
| 551 |
+
Used for training with low VRAM.
|
| 552 |
+
"""
|
| 553 |
+
pass
|
anigen/models/sparse_elastic_mixin.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from contextlib import contextmanager
|
| 2 |
+
from typing import *
|
| 3 |
+
import math
|
| 4 |
+
from ..modules import sparse as sp
|
| 5 |
+
from ..utils.elastic_utils import ElasticModuleMixin
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SparseTransformerElasticMixin(ElasticModuleMixin):
|
| 9 |
+
def _get_input_size(self, x: sp.SparseTensor, *args, **kwargs):
|
| 10 |
+
return x.feats.shape[0]
|
| 11 |
+
|
| 12 |
+
@contextmanager
|
| 13 |
+
def with_mem_ratio(self, mem_ratio=1.0):
|
| 14 |
+
if mem_ratio == 1.0:
|
| 15 |
+
yield 1.0
|
| 16 |
+
return
|
| 17 |
+
num_blocks = len(self.blocks)
|
| 18 |
+
num_checkpoint_blocks = min(math.ceil((1 - mem_ratio) * num_blocks) + 1, num_blocks)
|
| 19 |
+
exact_mem_ratio = 1 - (num_checkpoint_blocks - 1) / num_blocks
|
| 20 |
+
for i in range(num_blocks):
|
| 21 |
+
self.blocks[i].use_checkpoint = i < num_checkpoint_blocks
|
| 22 |
+
yield exact_mem_ratio
|
| 23 |
+
for i in range(num_blocks):
|
| 24 |
+
self.blocks[i].use_checkpoint = False
|
anigen/models/structured_latent_vae/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .anigen_encoder import AniGenElasticSLatEncoder
|
| 2 |
+
from .anigen_decoder import AniGenElasticSLatMeshDecoder
|
| 3 |
+
from .skin_models import SkinAutoEncoder
|
anigen/models/structured_latent_vae/anigen_base.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from ...modules import sparse as sp
|
| 5 |
+
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
|
| 6 |
+
from ...modules.sparse.transformer import SparseTransformerMultiContextCrossBlock, SparseTransformerBlock
|
| 7 |
+
from ...modules.transformer import AbsolutePositionEmbedder, TransformerCrossBlock
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FreqPositionalEmbedder(nn.Module):
|
| 11 |
+
def __init__(self, in_dim, include_input=True, max_freq_log2=8, num_freqs=8, log_sampling=True, periodic_fns=None):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.in_dim = in_dim
|
| 14 |
+
self.out_dim = None
|
| 15 |
+
self.include_input = include_input
|
| 16 |
+
self.max_freq_log2 = max_freq_log2
|
| 17 |
+
self.num_freqs = num_freqs
|
| 18 |
+
self.log_sampling = log_sampling
|
| 19 |
+
self.periodic_fns = periodic_fns if periodic_fns is not None else [
|
| 20 |
+
torch.sin, torch.cos
|
| 21 |
+
]
|
| 22 |
+
self.create_embedding_fn()
|
| 23 |
+
|
| 24 |
+
def create_embedding_fn(self):
|
| 25 |
+
embed_fns = []
|
| 26 |
+
d = self.in_dim
|
| 27 |
+
out_dim = 0
|
| 28 |
+
if self.include_input:
|
| 29 |
+
embed_fns.append(lambda x: x)
|
| 30 |
+
out_dim += d
|
| 31 |
+
|
| 32 |
+
max_freq = self.max_freq_log2
|
| 33 |
+
N_freqs = self.num_freqs
|
| 34 |
+
|
| 35 |
+
if self.log_sampling:
|
| 36 |
+
freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
|
| 37 |
+
else:
|
| 38 |
+
freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)
|
| 39 |
+
|
| 40 |
+
for freq in freq_bands:
|
| 41 |
+
for p_fn in self.periodic_fns:
|
| 42 |
+
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
|
| 43 |
+
out_dim += d
|
| 44 |
+
|
| 45 |
+
self.embed_fns = embed_fns
|
| 46 |
+
self.out_dim = out_dim
|
| 47 |
+
|
| 48 |
+
def forward(self, inputs):
|
| 49 |
+
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def block_attn_config(self, attn_mode_attr='attn_mode'):
|
| 53 |
+
"""
|
| 54 |
+
Return the attention configuration of the model.
|
| 55 |
+
"""
|
| 56 |
+
attn_mode = getattr(self, attn_mode_attr)
|
| 57 |
+
for i in range(self.num_blocks):
|
| 58 |
+
if attn_mode == "shift_window":
|
| 59 |
+
yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
|
| 60 |
+
elif attn_mode == "shift_sequence":
|
| 61 |
+
yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
|
| 62 |
+
elif attn_mode == "shift_order":
|
| 63 |
+
yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
|
| 64 |
+
elif attn_mode == "full":
|
| 65 |
+
yield "full", None, None, None, None
|
| 66 |
+
elif attn_mode == "swin":
|
| 67 |
+
yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class AniGenSparseTransformerBase(nn.Module):
|
| 71 |
+
"""
|
| 72 |
+
Sparse Transformer without output layers.
|
| 73 |
+
Serve as the base class for encoder and decoder.
|
| 74 |
+
"""
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
in_channels: int,
|
| 78 |
+
in_channels_skl: int,
|
| 79 |
+
in_channels_skin: int,
|
| 80 |
+
model_channels: int,
|
| 81 |
+
model_channels_skl: int,
|
| 82 |
+
model_channels_skin: int,
|
| 83 |
+
num_blocks: int,
|
| 84 |
+
num_heads: Optional[int] = None,
|
| 85 |
+
num_heads_skl: int = 8,
|
| 86 |
+
num_heads_skin: int = 8,
|
| 87 |
+
num_head_channels: Optional[int] = 64,
|
| 88 |
+
mlp_ratio: float = 4.0,
|
| 89 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
| 90 |
+
attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
|
| 91 |
+
window_size: Optional[int] = None,
|
| 92 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 93 |
+
use_fp16: bool = False,
|
| 94 |
+
use_checkpoint: bool = False,
|
| 95 |
+
qk_rms_norm: bool = False,
|
| 96 |
+
|
| 97 |
+
skin_cross_from_geo: bool = True,
|
| 98 |
+
skl_cross_from_geo: bool = True,
|
| 99 |
+
skin_skl_cross: bool = True,
|
| 100 |
+
):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.in_channels = in_channels
|
| 103 |
+
self.in_channels_skl = in_channels_skl
|
| 104 |
+
self.in_channels_skin = in_channels_skin
|
| 105 |
+
self.model_channels = model_channels
|
| 106 |
+
self.model_channels_skl = model_channels_skl
|
| 107 |
+
self.model_channels_skin = model_channels_skin
|
| 108 |
+
self.num_blocks = num_blocks
|
| 109 |
+
self.window_size = window_size
|
| 110 |
+
self.num_heads = num_heads or model_channels // num_head_channels
|
| 111 |
+
self.mlp_ratio = mlp_ratio
|
| 112 |
+
self.attn_mode = attn_mode
|
| 113 |
+
self.attn_mode_cross = attn_mode_cross
|
| 114 |
+
self.pe_mode = pe_mode
|
| 115 |
+
self.use_fp16 = use_fp16
|
| 116 |
+
self.use_checkpoint = use_checkpoint
|
| 117 |
+
self.qk_rms_norm = qk_rms_norm
|
| 118 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 119 |
+
self.skin_cross_from_geo = skin_cross_from_geo
|
| 120 |
+
self.skl_cross_from_geo = skl_cross_from_geo
|
| 121 |
+
self.skin_skl_cross = skin_skl_cross
|
| 122 |
+
|
| 123 |
+
if pe_mode == "ape":
|
| 124 |
+
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
| 125 |
+
self.pos_embedder_skl = AbsolutePositionEmbedder(model_channels_skl)
|
| 126 |
+
self.pos_embedder_skin = AbsolutePositionEmbedder(model_channels_skin)
|
| 127 |
+
|
| 128 |
+
self.input_layer = sp.SparseLinear(in_channels, model_channels)
|
| 129 |
+
self.input_layer_skl = sp.SparseLinear(in_channels_skl, model_channels_skl)
|
| 130 |
+
self.input_layer_skin = sp.SparseLinear(in_channels_skin, model_channels_skin)
|
| 131 |
+
|
| 132 |
+
self.blocks = nn.ModuleList([
|
| 133 |
+
SparseTransformerBlock(
|
| 134 |
+
model_channels,
|
| 135 |
+
num_heads=self.num_heads,
|
| 136 |
+
mlp_ratio=self.mlp_ratio,
|
| 137 |
+
attn_mode=attn_mode,
|
| 138 |
+
window_size=window_size,
|
| 139 |
+
shift_sequence=shift_sequence,
|
| 140 |
+
shift_window=shift_window,
|
| 141 |
+
serialize_mode=serialize_mode,
|
| 142 |
+
use_checkpoint=self.use_checkpoint,
|
| 143 |
+
use_rope=(pe_mode == "rope"),
|
| 144 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 145 |
+
)
|
| 146 |
+
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
|
| 147 |
+
])
|
| 148 |
+
|
| 149 |
+
ctx_channels = []
|
| 150 |
+
if skin_skl_cross:
|
| 151 |
+
ctx_channels.append(model_channels_skl)
|
| 152 |
+
if skin_cross_from_geo:
|
| 153 |
+
ctx_channels.append(model_channels)
|
| 154 |
+
self.blocks_skin = nn.ModuleList([
|
| 155 |
+
SparseTransformerMultiContextCrossBlock(
|
| 156 |
+
model_channels_skin,
|
| 157 |
+
ctx_channels=ctx_channels,
|
| 158 |
+
num_heads=num_heads_skin,
|
| 159 |
+
mlp_ratio=self.mlp_ratio,
|
| 160 |
+
attn_mode=attn_mode,
|
| 161 |
+
attn_mode_cross=attn_mode,
|
| 162 |
+
window_size=window_size,
|
| 163 |
+
shift_sequence=shift_sequence,
|
| 164 |
+
shift_window=shift_window,
|
| 165 |
+
serialize_mode=serialize_mode,
|
| 166 |
+
use_checkpoint=self.use_checkpoint,
|
| 167 |
+
use_rope=(pe_mode == "rope"),
|
| 168 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 169 |
+
cross_attn_cache_suffix='_skin',
|
| 170 |
+
)
|
| 171 |
+
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self, "attn_mode_cross")
|
| 172 |
+
])
|
| 173 |
+
|
| 174 |
+
ctx_channels = []
|
| 175 |
+
if skin_skl_cross:
|
| 176 |
+
ctx_channels.append(model_channels_skin)
|
| 177 |
+
if skl_cross_from_geo:
|
| 178 |
+
ctx_channels.append(model_channels)
|
| 179 |
+
self.blocks_skl = nn.ModuleList([
|
| 180 |
+
SparseTransformerMultiContextCrossBlock(
|
| 181 |
+
model_channels_skl,
|
| 182 |
+
ctx_channels=ctx_channels,
|
| 183 |
+
num_heads=num_heads_skl,
|
| 184 |
+
mlp_ratio=self.mlp_ratio,
|
| 185 |
+
attn_mode=attn_mode,
|
| 186 |
+
attn_mode_cross=attn_mode,
|
| 187 |
+
window_size=window_size,
|
| 188 |
+
shift_sequence=shift_sequence,
|
| 189 |
+
shift_window=shift_window,
|
| 190 |
+
serialize_mode=serialize_mode,
|
| 191 |
+
use_checkpoint=self.use_checkpoint,
|
| 192 |
+
use_rope=(pe_mode == "rope"),
|
| 193 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 194 |
+
cross_attn_cache_suffix='_skl',
|
| 195 |
+
)
|
| 196 |
+
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self, "attn_mode_cross")
|
| 197 |
+
])
|
| 198 |
+
|
| 199 |
+
@property
|
| 200 |
+
def device(self) -> torch.device:
|
| 201 |
+
"""
|
| 202 |
+
Return the device of the model.
|
| 203 |
+
"""
|
| 204 |
+
return next(self.parameters()).device
|
| 205 |
+
|
| 206 |
+
def convert_to_fp16(self) -> None:
|
| 207 |
+
"""
|
| 208 |
+
Convert the torso of the model to float16.
|
| 209 |
+
"""
|
| 210 |
+
self.blocks.apply(convert_module_to_f16)
|
| 211 |
+
self.blocks_skl.apply(convert_module_to_f16)
|
| 212 |
+
self.blocks_skin.apply(convert_module_to_f16)
|
| 213 |
+
|
| 214 |
+
def convert_to_fp32(self) -> None:
|
| 215 |
+
"""
|
| 216 |
+
Convert the torso of the model to float32.
|
| 217 |
+
"""
|
| 218 |
+
self.blocks.apply(convert_module_to_f32)
|
| 219 |
+
self.blocks_skl.apply(convert_module_to_f32)
|
| 220 |
+
self.blocks_skin.apply(convert_module_to_f32)
|
| 221 |
+
|
| 222 |
+
def initialize_weights(self) -> None:
|
| 223 |
+
# Initialize transformer layers:
|
| 224 |
+
def _basic_init(module):
|
| 225 |
+
if isinstance(module, nn.Linear):
|
| 226 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 227 |
+
if module.bias is not None:
|
| 228 |
+
nn.init.constant_(module.bias, 0)
|
| 229 |
+
self.apply(_basic_init)
|
| 230 |
+
|
| 231 |
+
def forward_input_layer(self, x: sp.SparseTensor, layer, pos_embedder) -> sp.SparseTensor:
|
| 232 |
+
h = layer(x)
|
| 233 |
+
if self.pe_mode == "ape":
|
| 234 |
+
h = h + pos_embedder(x.coords[:, 1:])
|
| 235 |
+
h = h.type(self.dtype)
|
| 236 |
+
return h
|
| 237 |
+
|
| 238 |
+
def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, x_skin: sp.SparseTensor) -> sp.SparseTensor:
|
| 239 |
+
h = self.forward_input_layer(x, self.input_layer, self.pos_embedder)
|
| 240 |
+
h_skl = self.forward_input_layer(x_skl, self.input_layer_skl, self.pos_embedder_skl)
|
| 241 |
+
h_skin = self.forward_input_layer(x_skin, self.input_layer_skin, self.pos_embedder_skin)
|
| 242 |
+
|
| 243 |
+
for block, block_skl, block_skin in zip(self.blocks, self.blocks_skl, self.blocks_skin):
|
| 244 |
+
f, f_skl, f_skin = h, h_skl, h_skin
|
| 245 |
+
h = block(f)
|
| 246 |
+
skl_contexts, skin_contexts = [], []
|
| 247 |
+
if self.skin_skl_cross:
|
| 248 |
+
skl_contexts.append(f_skin)
|
| 249 |
+
skin_contexts.append(f_skl)
|
| 250 |
+
if self.skl_cross_from_geo:
|
| 251 |
+
skl_contexts.append(f)
|
| 252 |
+
if self.skin_cross_from_geo:
|
| 253 |
+
skin_contexts.append(f)
|
| 254 |
+
h_skl = block_skl(f_skl, skl_contexts)
|
| 255 |
+
h_skin = block_skin(f_skin, skin_contexts)
|
| 256 |
+
return h, h_skl, h_skin
|
anigen/models/structured_latent_vae/anigen_decoder.py
ADDED
|
@@ -0,0 +1,834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import contextlib
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from ...modules.sparse.transformer import SparseTransformerMultiContextCrossBlock, SparseTransformerBlock
|
| 6 |
+
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
|
| 7 |
+
from ...modules import sparse as sp
|
| 8 |
+
from ...representations import MeshExtractResult
|
| 9 |
+
from ...representations.mesh import AniGenSparseFeatures2Mesh, AniGenSklFeatures2Skeleton
|
| 10 |
+
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
|
| 11 |
+
from pytorch3d.ops import knn_points
|
| 12 |
+
from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
|
| 13 |
+
from .skin_models import SKIN_MODEL_DICT
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from ...representations.skeleton.grouping import GROUPING_STRATEGIES
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SparseSubdivideBlock3d(nn.Module):
|
| 19 |
+
"""
|
| 20 |
+
A 3D subdivide block that can subdivide the sparse tensor.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
channels: channels in the inputs and outputs.
|
| 24 |
+
out_channels: if specified, the number of output channels.
|
| 25 |
+
num_groups: the number of groups for the group norm.
|
| 26 |
+
"""
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
channels: int,
|
| 30 |
+
resolution: int,
|
| 31 |
+
out_channels: Optional[int] = None,
|
| 32 |
+
num_groups: int = 32,
|
| 33 |
+
sub_divide: bool = True,
|
| 34 |
+
conv_as_residual: bool = False,
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.channels = channels
|
| 38 |
+
self.resolution = resolution
|
| 39 |
+
self.out_resolution = resolution * 2 if sub_divide else resolution
|
| 40 |
+
self.out_channels = out_channels or channels
|
| 41 |
+
self.sub_divide = sub_divide
|
| 42 |
+
self.conv_as_residual = conv_as_residual
|
| 43 |
+
|
| 44 |
+
self.act_layers = nn.Sequential(
|
| 45 |
+
sp.SparseGroupNorm32(num_groups, channels),
|
| 46 |
+
sp.SparseSiLU()
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
self.sub = sp.SparseSubdivide() if sub_divide else nn.Identity()
|
| 50 |
+
|
| 51 |
+
self.out_layers = nn.Sequential(
|
| 52 |
+
sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
|
| 53 |
+
sp.SparseGroupNorm32(num_groups, self.out_channels),
|
| 54 |
+
sp.SparseSiLU(),
|
| 55 |
+
zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
if self.out_channels == channels and not self.conv_as_residual:
|
| 59 |
+
self.skip_connection = nn.Identity()
|
| 60 |
+
else:
|
| 61 |
+
self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
|
| 62 |
+
|
| 63 |
+
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
| 64 |
+
"""
|
| 65 |
+
Apply the block to a Tensor, conditioned on a timestep embedding.
|
| 66 |
+
SparseConv3d
|
| 67 |
+
Args:
|
| 68 |
+
x: an [N x C x ...] Tensor of features.
|
| 69 |
+
Returns:
|
| 70 |
+
an [N x C x ...] Tensor of outputs.
|
| 71 |
+
"""
|
| 72 |
+
h = self.act_layers(x)
|
| 73 |
+
h = self.sub(h)
|
| 74 |
+
x = self.sub(x)
|
| 75 |
+
h = self.out_layers(h)
|
| 76 |
+
h = h + self.skip_connection(x)
|
| 77 |
+
return h
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class SparseDownsampleWithCache(nn.Module):
|
| 81 |
+
"""SparseDownsample that stores upsample caches under a unique suffix.
|
| 82 |
+
|
| 83 |
+
This avoids cache-key collisions when stacking multiple down/up stages.
|
| 84 |
+
"""
|
| 85 |
+
def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], cache_suffix: str):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
|
| 88 |
+
self.cache_suffix = cache_suffix
|
| 89 |
+
self._down = sp.SparseDownsample(self.factor)
|
| 90 |
+
|
| 91 |
+
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
| 92 |
+
out = self._down(x)
|
| 93 |
+
|
| 94 |
+
dim = out.coords.shape[-1] - 1
|
| 95 |
+
factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * dim
|
| 96 |
+
k_coords = f'upsample_{factor}_coords'
|
| 97 |
+
k_layout = f'upsample_{factor}_layout'
|
| 98 |
+
k_idx = f'upsample_{factor}_idx'
|
| 99 |
+
coords = out.get_spatial_cache(k_coords)
|
| 100 |
+
layout = out.get_spatial_cache(k_layout)
|
| 101 |
+
idx = out.get_spatial_cache(k_idx)
|
| 102 |
+
if any(v is None for v in [coords, layout, idx]):
|
| 103 |
+
raise ValueError('Downsample cache not found after SparseDownsample.')
|
| 104 |
+
|
| 105 |
+
# spconv expects int32 indices; SparseDownsample produces int64 coords.
|
| 106 |
+
if out.coords.dtype != torch.int32:
|
| 107 |
+
out = sp.SparseTensor(
|
| 108 |
+
out.feats,
|
| 109 |
+
out.coords.to(torch.int32),
|
| 110 |
+
out.shape,
|
| 111 |
+
out.layout,
|
| 112 |
+
scale=out._scale,
|
| 113 |
+
spatial_cache=out._spatial_cache,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_coords', coords)
|
| 117 |
+
out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_layout', layout)
|
| 118 |
+
out.register_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_idx', idx)
|
| 119 |
+
# Remove unsuffixed keys to prevent later stages overwriting them.
|
| 120 |
+
try:
|
| 121 |
+
del out._spatial_cache[k_coords]
|
| 122 |
+
del out._spatial_cache[k_layout]
|
| 123 |
+
del out._spatial_cache[k_idx]
|
| 124 |
+
except Exception:
|
| 125 |
+
pass
|
| 126 |
+
return out
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class SparseUpsampleWithCache(nn.Module):
|
| 130 |
+
"""SparseUpsample that reads upsample caches under a unique suffix."""
|
| 131 |
+
def __init__(self, factor: Union[int, Tuple[int, ...], List[int]], cache_suffix: str):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
|
| 134 |
+
self.cache_suffix = cache_suffix
|
| 135 |
+
|
| 136 |
+
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
| 137 |
+
dim = x.coords.shape[-1] - 1
|
| 138 |
+
factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * dim
|
| 139 |
+
new_coords = x.get_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_coords')
|
| 140 |
+
new_layout = x.get_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_layout')
|
| 141 |
+
idx = x.get_spatial_cache(f'upsample_{factor}_{self.cache_suffix}_idx')
|
| 142 |
+
if any(v is None for v in [new_coords, new_layout, idx]):
|
| 143 |
+
raise ValueError('Upsample cache not found. Must be paired with SparseDownsampleWithCache.')
|
| 144 |
+
if new_coords.dtype != torch.int32:
|
| 145 |
+
new_coords = new_coords.to(torch.int32)
|
| 146 |
+
new_feats = x.feats[idx]
|
| 147 |
+
out = sp.SparseTensor(new_feats, new_coords, x.shape, new_layout)
|
| 148 |
+
out._scale = tuple([s * f for s, f in zip(x._scale, factor)])
|
| 149 |
+
out._spatial_cache = x._spatial_cache
|
| 150 |
+
return out
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class SparseSkinUNetNLevel(nn.Module):
|
| 154 |
+
"""A simple N-down/N-up sparse UNet for local smoothing.
|
| 155 |
+
|
| 156 |
+
Note: `SparseSubdivideBlock3d` uses `resolution` only to name spconv `indice_key`s.
|
| 157 |
+
We must provide distinct (and stage-appropriate) values per hierarchy to avoid
|
| 158 |
+
rulebook collisions across different coordinate sets.
|
| 159 |
+
"""
|
| 160 |
+
def __init__(self, channels: int, base_resolution: int, num_groups: int = 32, num_levels: int = 3):
|
| 161 |
+
super().__init__()
|
| 162 |
+
|
| 163 |
+
if num_levels < 1:
|
| 164 |
+
raise ValueError(f"num_levels must be >= 1, got {num_levels}")
|
| 165 |
+
self.channels = channels
|
| 166 |
+
self.base_resolution = int(base_resolution)
|
| 167 |
+
self.num_groups = num_groups
|
| 168 |
+
self.num_levels = int(num_levels)
|
| 169 |
+
|
| 170 |
+
def res_block(resolution: int):
|
| 171 |
+
return SparseSubdivideBlock3d(
|
| 172 |
+
channels=channels,
|
| 173 |
+
resolution=resolution,
|
| 174 |
+
out_channels=channels,
|
| 175 |
+
sub_divide=False,
|
| 176 |
+
conv_as_residual=True,
|
| 177 |
+
num_groups=num_groups,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# resolutions[i] corresponds to the i-th encoder stage (before downsample)
|
| 181 |
+
resolutions: List[int] = [max(1, self.base_resolution // (2 ** i)) for i in range(self.num_levels)]
|
| 182 |
+
bottom_resolution = max(1, self.base_resolution // (2 ** self.num_levels))
|
| 183 |
+
|
| 184 |
+
self.enc = nn.ModuleList([res_block(r) for r in resolutions])
|
| 185 |
+
self.down = nn.ModuleList([SparseDownsampleWithCache(2, f'unet{i}') for i in range(self.num_levels)])
|
| 186 |
+
self.mid = res_block(bottom_resolution)
|
| 187 |
+
|
| 188 |
+
# Decoder blocks operate at the same resolutions as encoder blocks.
|
| 189 |
+
self.up = nn.ModuleList([SparseUpsampleWithCache(2, f'unet{i}') for i in range(self.num_levels)])
|
| 190 |
+
self.fuse = nn.ModuleList([sp.SparseLinear(channels * 2, channels) for _ in range(self.num_levels)])
|
| 191 |
+
self.dec = nn.ModuleList([res_block(r) for r in resolutions])
|
| 192 |
+
|
| 193 |
+
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
| 194 |
+
in_dtype = x.feats.dtype
|
| 195 |
+
if x.coords.dtype != torch.int32:
|
| 196 |
+
x = sp.SparseTensor(
|
| 197 |
+
x.feats,
|
| 198 |
+
x.coords.to(torch.int32),
|
| 199 |
+
x.shape,
|
| 200 |
+
x.layout,
|
| 201 |
+
scale=x._scale,
|
| 202 |
+
spatial_cache=x._spatial_cache,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# spconv implicit_gemm has a runtime tuner that can fail for some sparse
|
| 206 |
+
# rulebooks under AMP + fp16/bf16. Running UNet convs in fp32 avoids that.
|
| 207 |
+
if hasattr(torch, 'autocast'):
|
| 208 |
+
autocast_ctx = torch.autocast(device_type=x.device.type, enabled=False)
|
| 209 |
+
else:
|
| 210 |
+
# Older torch fallback
|
| 211 |
+
autocast_ctx = torch.cuda.amp.autocast(enabled=False) if x.device.type == 'cuda' else contextlib.nullcontext()
|
| 212 |
+
|
| 213 |
+
with autocast_ctx:
|
| 214 |
+
x_fp32 = x if x.feats.dtype == torch.float32 else x.replace(x.feats.float())
|
| 215 |
+
|
| 216 |
+
skips: List[sp.SparseTensor] = []
|
| 217 |
+
h = x_fp32
|
| 218 |
+
for i in range(self.num_levels):
|
| 219 |
+
s = self.enc[i](h)
|
| 220 |
+
skips.append(s)
|
| 221 |
+
h = self.down[i](s)
|
| 222 |
+
|
| 223 |
+
h = self.mid(h)
|
| 224 |
+
|
| 225 |
+
for i in reversed(range(self.num_levels)):
|
| 226 |
+
h_up = self.up[i](h)
|
| 227 |
+
s = skips[i]
|
| 228 |
+
h = self.fuse[i](h_up.replace(torch.cat([h_up.feats, s.feats], dim=-1)))
|
| 229 |
+
h = self.dec[i](h)
|
| 230 |
+
|
| 231 |
+
u0 = h
|
| 232 |
+
|
| 233 |
+
if in_dtype != u0.feats.dtype:
|
| 234 |
+
u0 = u0.replace(u0.feats.to(dtype=in_dtype))
|
| 235 |
+
return u0
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class AniGenSLatMeshDecoder(AniGenSparseTransformerBase):
|
| 239 |
+
def __init__(
|
| 240 |
+
self,
|
| 241 |
+
resolution: int,
|
| 242 |
+
model_channels: int,
|
| 243 |
+
model_channels_skl: int,
|
| 244 |
+
model_channels_skin: int,
|
| 245 |
+
|
| 246 |
+
latent_channels: int,
|
| 247 |
+
latent_channels_skl: int,
|
| 248 |
+
latent_channels_vertskin: int,
|
| 249 |
+
|
| 250 |
+
num_blocks: int,
|
| 251 |
+
num_heads: Optional[int] = None,
|
| 252 |
+
num_head_channels: Optional[int] = 64,
|
| 253 |
+
|
| 254 |
+
num_heads_skl: int = 32,
|
| 255 |
+
num_heads_skin: int = 32,
|
| 256 |
+
|
| 257 |
+
skin_cross_from_groupped: bool = False,
|
| 258 |
+
h_skin_unet_num_levels: int = 4,
|
| 259 |
+
|
| 260 |
+
skin_decoder_config: Optional[Dict[str, Any]] = {},
|
| 261 |
+
|
| 262 |
+
upsample_skl: bool = False,
|
| 263 |
+
skl_defined_on_center: bool = True,
|
| 264 |
+
mlp_ratio: float = 4,
|
| 265 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
|
| 266 |
+
attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
|
| 267 |
+
window_size: int = 8,
|
| 268 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 269 |
+
use_fp16: bool = False,
|
| 270 |
+
use_checkpoint: bool = False,
|
| 271 |
+
qk_rms_norm: bool = False,
|
| 272 |
+
representation_config: dict = None,
|
| 273 |
+
|
| 274 |
+
use_pretrain_branch: bool = True,
|
| 275 |
+
freeze_pretrain_branch: bool = True,
|
| 276 |
+
modules_to_freeze: Optional[List[str]] = ["blocks", "upsample", "out_layer", "skin_decoder"],
|
| 277 |
+
|
| 278 |
+
skin_cross_from_geo: bool = False,
|
| 279 |
+
skl_cross_from_geo: bool = False,
|
| 280 |
+
skin_skl_cross: bool = False,
|
| 281 |
+
skin_ae_name: str = "SkinAE",
|
| 282 |
+
|
| 283 |
+
normalize_z: bool = False,
|
| 284 |
+
normalize_scale: float = 1.0,
|
| 285 |
+
|
| 286 |
+
jp_residual_fields: bool = True,
|
| 287 |
+
jp_hyper_continuous: bool = True,
|
| 288 |
+
|
| 289 |
+
grouping_strategy: Literal["mean_shift", "threshold"] = "mean_shift",
|
| 290 |
+
|
| 291 |
+
vertex_skin_feat_interp_sparse: bool = False,
|
| 292 |
+
vertex_skin_feat_interp_nearest: bool = False,
|
| 293 |
+
vertex_skin_feat_interp_use_deformed_grid: bool = False,
|
| 294 |
+
vertex_skin_feat_interp_trilinear: bool = False,
|
| 295 |
+
flexicube_disable_deform: bool = False,
|
| 296 |
+
vertex_skin_feat_nodeform_trilinear: bool = False,
|
| 297 |
+
):
|
| 298 |
+
super().__init__(
|
| 299 |
+
in_channels=latent_channels,
|
| 300 |
+
in_channels_skl=latent_channels_skl,
|
| 301 |
+
in_channels_skin=latent_channels_vertskin,
|
| 302 |
+
model_channels=model_channels,
|
| 303 |
+
model_channels_skl=model_channels_skl,
|
| 304 |
+
model_channels_skin=model_channels_skin,
|
| 305 |
+
num_blocks=num_blocks,
|
| 306 |
+
num_heads=num_heads,
|
| 307 |
+
num_heads_skl=num_heads_skl,
|
| 308 |
+
num_heads_skin=num_heads_skin,
|
| 309 |
+
num_head_channels=num_head_channels,
|
| 310 |
+
mlp_ratio=mlp_ratio,
|
| 311 |
+
attn_mode=attn_mode,
|
| 312 |
+
attn_mode_cross=attn_mode_cross,
|
| 313 |
+
window_size=window_size,
|
| 314 |
+
pe_mode=pe_mode,
|
| 315 |
+
use_fp16=use_fp16,
|
| 316 |
+
use_checkpoint=use_checkpoint,
|
| 317 |
+
qk_rms_norm=qk_rms_norm,
|
| 318 |
+
skin_cross_from_geo=skin_cross_from_geo,
|
| 319 |
+
skl_cross_from_geo=skl_cross_from_geo,
|
| 320 |
+
skin_skl_cross=skin_skl_cross,
|
| 321 |
+
)
|
| 322 |
+
self.pretrain_class_name = ["AniGenElasticSLatMeshDecoder", skin_ae_name]
|
| 323 |
+
self.pretrain_ckpt_filter_prefix = {skin_ae_name: "skin_decoder"}
|
| 324 |
+
self.latent_channels = latent_channels
|
| 325 |
+
self.latent_channels_skl = latent_channels_skl
|
| 326 |
+
self.latent_channels_vertskin = latent_channels_vertskin
|
| 327 |
+
self.jp_residual_fields = jp_residual_fields
|
| 328 |
+
self.jp_hyper_continuous = jp_hyper_continuous
|
| 329 |
+
self.grouping_func = GROUPING_STRATEGIES[grouping_strategy]
|
| 330 |
+
self.skin_cross_from_groupped = skin_cross_from_groupped
|
| 331 |
+
|
| 332 |
+
self.normalize_z = normalize_z
|
| 333 |
+
self.normalize_scale = normalize_scale
|
| 334 |
+
|
| 335 |
+
skin_decoder_config['use_fp16'] = use_fp16
|
| 336 |
+
self.skin_decoder = SKIN_MODEL_DICT[skin_decoder_config.pop('model_type')](**skin_decoder_config)
|
| 337 |
+
self.skin_feat_channels = self.skin_decoder.skin_feat_channels
|
| 338 |
+
|
| 339 |
+
# Optional local smoothing UNet on h_skin (independent of grouped cross-attn).
|
| 340 |
+
# If `h_skin_unet_num_levels < 0`, UNet is disabled.
|
| 341 |
+
self.h_skin_unet_num_levels = int(h_skin_unet_num_levels)
|
| 342 |
+
if self.h_skin_unet_num_levels >= 1:
|
| 343 |
+
self.h_skin_unet = SparseSkinUNetNLevel(
|
| 344 |
+
model_channels_skin,
|
| 345 |
+
base_resolution=resolution,
|
| 346 |
+
num_levels=self.h_skin_unet_num_levels,
|
| 347 |
+
)
|
| 348 |
+
else:
|
| 349 |
+
self.h_skin_unet = None
|
| 350 |
+
|
| 351 |
+
if self.skin_cross_from_groupped:
|
| 352 |
+
# Trainable parent feature for root joints (where parent_idx < 0).
|
| 353 |
+
self.root_parent_feat = nn.Parameter(torch.zeros(self.skin_feat_channels))
|
| 354 |
+
|
| 355 |
+
# Joint feature preprocessing: [joint_skin, fourier(joint_xyz), parent_skin] -> proj -> self-attn
|
| 356 |
+
self.joints_pos_embedder = FreqPositionalEmbedder(
|
| 357 |
+
in_dim=3,
|
| 358 |
+
include_input=True,
|
| 359 |
+
max_freq_log2=6,
|
| 360 |
+
num_freqs=6,
|
| 361 |
+
log_sampling=True,
|
| 362 |
+
)
|
| 363 |
+
joints_pe_dim = self.joints_pos_embedder.out_dim
|
| 364 |
+
joints_in_dim = self.skin_feat_channels + joints_pe_dim + self.skin_feat_channels
|
| 365 |
+
self.joints_ctx_channels = model_channels_skin
|
| 366 |
+
self.joints_in_proj = nn.Sequential(
|
| 367 |
+
nn.Linear(joints_in_dim, self.joints_ctx_channels, bias=True),
|
| 368 |
+
nn.SiLU(),
|
| 369 |
+
nn.LayerNorm(self.joints_ctx_channels, elementwise_affine=True),
|
| 370 |
+
)
|
| 371 |
+
self.joints_self_attn = nn.ModuleList([
|
| 372 |
+
SparseTransformerBlock(
|
| 373 |
+
self.joints_ctx_channels,
|
| 374 |
+
num_heads=num_heads_skin,
|
| 375 |
+
mlp_ratio=self.mlp_ratio,
|
| 376 |
+
attn_mode="full",
|
| 377 |
+
window_size=None,
|
| 378 |
+
use_checkpoint=self.use_checkpoint,
|
| 379 |
+
use_rope=False,
|
| 380 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 381 |
+
ln_affine=True,
|
| 382 |
+
) for _ in range(4)
|
| 383 |
+
])
|
| 384 |
+
|
| 385 |
+
# Coordinate PE for h_skin before cross-attn: coords in [-1, 1] -> Fourier PE -> proj(C), concat, fuse back to C.
|
| 386 |
+
self.h_skin_coord_embedder = FreqPositionalEmbedder(
|
| 387 |
+
in_dim=3,
|
| 388 |
+
include_input=True,
|
| 389 |
+
max_freq_log2=6,
|
| 390 |
+
num_freqs=6,
|
| 391 |
+
log_sampling=True,
|
| 392 |
+
)
|
| 393 |
+
h_skin_pe_dim = self.h_skin_coord_embedder.out_dim
|
| 394 |
+
self.h_skin_coord_proj = nn.Linear(h_skin_pe_dim, model_channels_skin, bias=True)
|
| 395 |
+
self.h_skin_coord_fuse = sp.SparseLinear(model_channels_skin * 2, model_channels_skin)
|
| 396 |
+
|
| 397 |
+
self.skin_cross_groupped_net = SparseTransformerMultiContextCrossBlock(
|
| 398 |
+
model_channels_skin,
|
| 399 |
+
# Context includes processed joint tokens + raw joint skin feats (skip connection).
|
| 400 |
+
ctx_channels=[self.joints_ctx_channels + self.skin_feat_channels],
|
| 401 |
+
num_heads=num_heads_skin,
|
| 402 |
+
mlp_ratio=self.mlp_ratio,
|
| 403 |
+
attn_mode="full",
|
| 404 |
+
attn_mode_cross="full",
|
| 405 |
+
cross_attn_cache_suffix='_skin_cross_from_groupped',
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
self.resolution = resolution
|
| 409 |
+
self.use_pretrain_branch = use_pretrain_branch
|
| 410 |
+
self.freeze_pretrain_branch = freeze_pretrain_branch
|
| 411 |
+
self.upsample_skl = upsample_skl
|
| 412 |
+
self.rep_config = representation_config
|
| 413 |
+
self.mesh_extractor = AniGenSparseFeatures2Mesh(
|
| 414 |
+
res=self.resolution*4,
|
| 415 |
+
use_color=self.rep_config.get('use_color', False),
|
| 416 |
+
skin_feat_channels=self.skin_feat_channels,
|
| 417 |
+
predict_skin=True,
|
| 418 |
+
vertex_skin_feat_interp_sparse=vertex_skin_feat_interp_sparse,
|
| 419 |
+
vertex_skin_feat_interp_nearest=vertex_skin_feat_interp_nearest,
|
| 420 |
+
vertex_skin_feat_interp_use_deformed_grid=vertex_skin_feat_interp_use_deformed_grid,
|
| 421 |
+
vertex_skin_feat_interp_trilinear=vertex_skin_feat_interp_trilinear,
|
| 422 |
+
flexicube_disable_deform=flexicube_disable_deform,
|
| 423 |
+
vertex_skin_feat_nodeform_trilinear=vertex_skin_feat_nodeform_trilinear,
|
| 424 |
+
)
|
| 425 |
+
self.out_channels = self.mesh_extractor.feats_channels
|
| 426 |
+
self.upsample = nn.ModuleList([
|
| 427 |
+
SparseSubdivideBlock3d(
|
| 428 |
+
channels=model_channels,
|
| 429 |
+
resolution=resolution,
|
| 430 |
+
out_channels=model_channels // 4
|
| 431 |
+
),
|
| 432 |
+
SparseSubdivideBlock3d(
|
| 433 |
+
channels=model_channels // 4,
|
| 434 |
+
resolution=resolution * 2,
|
| 435 |
+
out_channels=model_channels // 8
|
| 436 |
+
)
|
| 437 |
+
])
|
| 438 |
+
upsample_skin_blocks = []
|
| 439 |
+
upsample_skin_blocks.extend([
|
| 440 |
+
SparseSubdivideBlock3d(
|
| 441 |
+
channels=model_channels_skin,
|
| 442 |
+
resolution=resolution,
|
| 443 |
+
out_channels=model_channels // 4
|
| 444 |
+
),
|
| 445 |
+
SparseSubdivideBlock3d(
|
| 446 |
+
channels=model_channels // 4,
|
| 447 |
+
resolution=resolution * 2,
|
| 448 |
+
out_channels=model_channels // 8
|
| 449 |
+
)
|
| 450 |
+
])
|
| 451 |
+
|
| 452 |
+
self.upsample_skin_net = nn.ModuleList(upsample_skin_blocks)
|
| 453 |
+
self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
|
| 454 |
+
self.out_layer_skin = sp.SparseLinear(model_channels // 8, self.skin_feat_channels*8)
|
| 455 |
+
self.out_layer_skl_skin = sp.SparseLinear(model_channels // 8 if upsample_skl else model_channels_skl, self.skin_feat_channels if skl_defined_on_center else self.skin_feat_channels * 8)
|
| 456 |
+
self.use_conf_jp = self.rep_config.get('use_conf_jp', False) or self.jp_hyper_continuous
|
| 457 |
+
self.use_conf_skin = self.rep_config.get('use_conf_skin', False)
|
| 458 |
+
|
| 459 |
+
res_skl = self.resolution * 4 if self.upsample_skl else self.resolution
|
| 460 |
+
self.skeleton_extractor = AniGenSklFeatures2Skeleton(skin_feat_channels=self.skin_feat_channels, device=self.device, res=res_skl, use_conf_jp=self.use_conf_jp, use_conf_skin=self.use_conf_skin, predict_skin=True, defined_on_center=skl_defined_on_center, jp_hyper_continuous=self.jp_hyper_continuous, jp_residual_fields=self.jp_residual_fields)
|
| 461 |
+
|
| 462 |
+
self.out_channels_skl = self.skeleton_extractor.feats_channels
|
| 463 |
+
if self.upsample_skl:
|
| 464 |
+
self.upsample_skl_net = nn.ModuleList([
|
| 465 |
+
SparseSubdivideBlock3d(
|
| 466 |
+
channels=model_channels_skl,
|
| 467 |
+
resolution=resolution,
|
| 468 |
+
out_channels=model_channels // 4
|
| 469 |
+
),
|
| 470 |
+
SparseSubdivideBlock3d(
|
| 471 |
+
channels=model_channels // 4,
|
| 472 |
+
resolution=resolution * 2,
|
| 473 |
+
out_channels=model_channels // 8
|
| 474 |
+
)
|
| 475 |
+
])
|
| 476 |
+
self.out_layer_skl = sp.SparseLinear(model_channels // 8, self.out_channels_skl)
|
| 477 |
+
else:
|
| 478 |
+
self.out_layer_skl = sp.SparseLinear(model_channels_skl, self.out_channels_skl)
|
| 479 |
+
|
| 480 |
+
self.initialize_weights()
|
| 481 |
+
if use_fp16:
|
| 482 |
+
self.convert_to_fp16()
|
| 483 |
+
else:
|
| 484 |
+
self.convert_to_fp32()
|
| 485 |
+
|
| 486 |
+
if self.use_pretrain_branch and self.freeze_pretrain_branch:
|
| 487 |
+
for module in modules_to_freeze:
|
| 488 |
+
if hasattr(self, module):
|
| 489 |
+
mod = getattr(self, module)
|
| 490 |
+
if isinstance(mod, nn.ModuleList):
|
| 491 |
+
for m in mod:
|
| 492 |
+
for name, param in m.named_parameters():
|
| 493 |
+
if 'lora' not in name:
|
| 494 |
+
param.requires_grad = False
|
| 495 |
+
elif isinstance(mod, nn.Module):
|
| 496 |
+
for name, param in mod.named_parameters():
|
| 497 |
+
if 'lora' not in name:
|
| 498 |
+
param.requires_grad = False
|
| 499 |
+
elif isinstance(mod, torch.Tensor):
|
| 500 |
+
if mod.requires_grad:
|
| 501 |
+
mod.requires_grad = False
|
| 502 |
+
|
| 503 |
+
def initialize_weights(self) -> None:
|
| 504 |
+
super().initialize_weights()
|
| 505 |
+
scale = 1e-4
|
| 506 |
+
# Kaiming initialization for output layers (better for ReLU/SiLU-like activations)
|
| 507 |
+
nn.init.kaiming_normal_(self.out_layer.weight, mode='fan_in', nonlinearity='relu')
|
| 508 |
+
self.out_layer.weight.data.mul_(scale)
|
| 509 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
| 510 |
+
|
| 511 |
+
nn.init.kaiming_normal_(self.out_layer_skl.weight, mode='fan_in', nonlinearity='relu')
|
| 512 |
+
self.out_layer_skl.weight.data.mul_(scale)
|
| 513 |
+
nn.init.constant_(self.out_layer_skl.bias, 0)
|
| 514 |
+
|
| 515 |
+
# Initialize skin layer:
|
| 516 |
+
self.skin_decoder.initialize_weights()
|
| 517 |
+
nn.init.kaiming_normal_(self.out_layer_skin.weight, mode='fan_in', nonlinearity='relu')
|
| 518 |
+
self.out_layer_skin.weight.data.mul_(scale)
|
| 519 |
+
nn.init.constant_(self.out_layer_skin.bias, 0)
|
| 520 |
+
|
| 521 |
+
nn.init.kaiming_normal_(self.out_layer_skl_skin.weight, mode='fan_in', nonlinearity='relu')
|
| 522 |
+
self.out_layer_skl_skin.weight.data.mul_(scale)
|
| 523 |
+
nn.init.constant_(self.out_layer_skl_skin.bias, 0)
|
| 524 |
+
|
| 525 |
+
def convert_to_fp16(self) -> None:
|
| 526 |
+
"""
|
| 527 |
+
Convert the torso of the model to float16.
|
| 528 |
+
"""
|
| 529 |
+
super().convert_to_fp16()
|
| 530 |
+
self.upsample.apply(convert_module_to_f16)
|
| 531 |
+
self.upsample_skin_net.apply(convert_module_to_f16)
|
| 532 |
+
if self.upsample_skl:
|
| 533 |
+
self.upsample_skl_net.apply(convert_module_to_f16)
|
| 534 |
+
if self.skin_cross_from_groupped:
|
| 535 |
+
# Joint preprocessing and cross-attn should match model dtype.
|
| 536 |
+
self.root_parent_feat.data = self.root_parent_feat.data.half()
|
| 537 |
+
self.joints_in_proj.apply(convert_module_to_f16)
|
| 538 |
+
self.joints_self_attn.apply(convert_module_to_f16)
|
| 539 |
+
|
| 540 |
+
# `convert_module_to_f16` doesn't include `nn.LayerNorm`, so cast LN params explicitly.
|
| 541 |
+
for _m in self.joints_in_proj.modules():
|
| 542 |
+
if isinstance(_m, nn.LayerNorm):
|
| 543 |
+
if _m.weight is not None:
|
| 544 |
+
_m.weight.data = _m.weight.data.half()
|
| 545 |
+
if _m.bias is not None:
|
| 546 |
+
_m.bias.data = _m.bias.data.half()
|
| 547 |
+
|
| 548 |
+
# IMPORTANT: `SparseTransformerBlock` uses `LayerNorm32` which internally
|
| 549 |
+
# normalizes in fp32 (`x.float()`), so its parameters must stay fp32.
|
| 550 |
+
for _m in self.joints_self_attn.modules():
|
| 551 |
+
if isinstance(_m, nn.LayerNorm):
|
| 552 |
+
if _m.weight is not None:
|
| 553 |
+
_m.weight.data = _m.weight.data.float()
|
| 554 |
+
if _m.bias is not None:
|
| 555 |
+
_m.bias.data = _m.bias.data.float()
|
| 556 |
+
|
| 557 |
+
self.skin_cross_groupped_net.apply(convert_module_to_f16)
|
| 558 |
+
self.h_skin_coord_proj.apply(convert_module_to_f16)
|
| 559 |
+
self.h_skin_coord_fuse.apply(convert_module_to_f16)
|
| 560 |
+
|
| 561 |
+
# UNet is executed in fp32 (see `SparseSkinUNetNLevel.forward`), so keep its
|
| 562 |
+
# weights in fp32 to avoid dtype mismatches inside spconv.
|
| 563 |
+
if self.h_skin_unet is not None:
|
| 564 |
+
self.h_skin_unet.apply(convert_module_to_f32)
|
| 565 |
+
self.skin_decoder.convert_to_fp16()
|
| 566 |
+
|
| 567 |
+
def convert_to_fp32(self) -> None:
|
| 568 |
+
"""
|
| 569 |
+
Convert the torso of the model to float32.
|
| 570 |
+
"""
|
| 571 |
+
super().convert_to_fp32()
|
| 572 |
+
self.upsample.apply(convert_module_to_f32)
|
| 573 |
+
self.upsample_skin_net.apply(convert_module_to_f32)
|
| 574 |
+
if self.upsample_skl:
|
| 575 |
+
self.upsample_skl_net.apply(convert_module_to_f32)
|
| 576 |
+
if self.skin_cross_from_groupped:
|
| 577 |
+
self.root_parent_feat.data = self.root_parent_feat.data.float()
|
| 578 |
+
self.joints_in_proj.apply(convert_module_to_f32)
|
| 579 |
+
self.joints_self_attn.apply(convert_module_to_f32)
|
| 580 |
+
|
| 581 |
+
for _m in self.joints_in_proj.modules():
|
| 582 |
+
if isinstance(_m, nn.LayerNorm):
|
| 583 |
+
if _m.weight is not None:
|
| 584 |
+
_m.weight.data = _m.weight.data.float()
|
| 585 |
+
if _m.bias is not None:
|
| 586 |
+
_m.bias.data = _m.bias.data.float()
|
| 587 |
+
for _m in self.joints_self_attn.modules():
|
| 588 |
+
if isinstance(_m, nn.LayerNorm):
|
| 589 |
+
if _m.weight is not None:
|
| 590 |
+
_m.weight.data = _m.weight.data.float()
|
| 591 |
+
if _m.bias is not None:
|
| 592 |
+
_m.bias.data = _m.bias.data.float()
|
| 593 |
+
|
| 594 |
+
self.skin_cross_groupped_net.apply(convert_module_to_f32)
|
| 595 |
+
self.h_skin_coord_proj.apply(convert_module_to_f32)
|
| 596 |
+
self.h_skin_coord_fuse.apply(convert_module_to_f32)
|
| 597 |
+
if self.h_skin_unet is not None:
|
| 598 |
+
self.h_skin_unet.apply(convert_module_to_f32)
|
| 599 |
+
self.skin_decoder.convert_to_fp32()
|
| 600 |
+
|
| 601 |
+
def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
|
| 602 |
+
"""
|
| 603 |
+
Convert a batch of network outputs to 3D representations.
|
| 604 |
+
|
| 605 |
+
Args:
|
| 606 |
+
x: The [N x * x C] sparse tensor output by the network.
|
| 607 |
+
|
| 608 |
+
Returns:
|
| 609 |
+
list of representations
|
| 610 |
+
"""
|
| 611 |
+
ret = []
|
| 612 |
+
for i in range(x.shape[0]):
|
| 613 |
+
mesh = self.mesh_extractor(x[i], training=self.training)
|
| 614 |
+
ret.append(mesh)
|
| 615 |
+
return ret
|
| 616 |
+
|
| 617 |
+
def to_representation_skl(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
|
| 618 |
+
"""
|
| 619 |
+
Convert a batch of network outputs to skeleton representations.
|
| 620 |
+
|
| 621 |
+
Args:
|
| 622 |
+
x: The [N x * x C] sparse tensor output by the network.
|
| 623 |
+
|
| 624 |
+
Returns:
|
| 625 |
+
list of skeleton representations
|
| 626 |
+
"""
|
| 627 |
+
ret = []
|
| 628 |
+
for i in range(x.shape[0]):
|
| 629 |
+
skl = self.skeleton_extractor(x[i], training=self.training)
|
| 630 |
+
ret.append(skl)
|
| 631 |
+
return ret
|
| 632 |
+
|
| 633 |
+
def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, gt_joints=None, gt_parents=None) -> List[MeshExtractResult]:
|
| 634 |
+
x0 = x
|
| 635 |
+
x_skin = sp.SparseTensor(feats=x0.feats[:, self.latent_channels:], coords=x0.coords.clone())
|
| 636 |
+
x = x0.replace(x0.feats[:, :self.latent_channels])
|
| 637 |
+
if self.normalize_z:
|
| 638 |
+
x_skin = x_skin.replace(F.normalize(x_skin.feats, dim=-1))
|
| 639 |
+
x_skl = x_skl.replace(F.normalize(x_skl.feats, dim=-1))
|
| 640 |
+
|
| 641 |
+
# Backbone forward
|
| 642 |
+
h, h_skl, h_skin = super().forward(x, x_skl, x_skin)
|
| 643 |
+
|
| 644 |
+
# Optional smoothing on h_skin.
|
| 645 |
+
if self.h_skin_unet is not None:
|
| 646 |
+
h_skin = self.h_skin_unet(h_skin)
|
| 647 |
+
|
| 648 |
+
# Skeleton prediction
|
| 649 |
+
if self.upsample_skl:
|
| 650 |
+
for block_skl in self.upsample_skl_net:
|
| 651 |
+
h_skl = block_skl(h_skl)
|
| 652 |
+
h_skl_middle = h_skl.type(x_skl.dtype)
|
| 653 |
+
h_skl = self.out_layer_skl(h_skl_middle)
|
| 654 |
+
h_skl_skin = self.out_layer_skl_skin(h_skl_middle)
|
| 655 |
+
h_skl = h_skl.replace(torch.cat([h_skl.feats, h_skl_skin.feats], dim=-1))
|
| 656 |
+
skeletons = self.to_representation_skl(h_skl)
|
| 657 |
+
skin_feats_joints_list = self.skeleton_grouping(skeletons, gt_joints=gt_joints, gt_parents=gt_parents)
|
| 658 |
+
|
| 659 |
+
# Skin cross with grouped joint features
|
| 660 |
+
if self.skin_cross_from_groupped:
|
| 661 |
+
coords_xyz = h_skin.coords[:, 1:].to(device=h_skin.device, dtype=torch.float32)
|
| 662 |
+
coords_norm = (coords_xyz + 0.5) / self.resolution * 2.0 - 1.0
|
| 663 |
+
coords_pe = self.h_skin_coord_embedder(coords_norm)
|
| 664 |
+
coords_pe = coords_pe.to(device=h_skin.device, dtype=h_skin.feats.dtype)
|
| 665 |
+
coords_pe = self.h_skin_coord_proj(coords_pe)
|
| 666 |
+
h_skin = h_skin.replace(torch.cat([h_skin.feats, coords_pe], dim=-1))
|
| 667 |
+
h_skin = self.h_skin_coord_fuse(h_skin)
|
| 668 |
+
joints_ctx = self._build_processed_joints_context(
|
| 669 |
+
skeletons,
|
| 670 |
+
skin_feats_joints_list,
|
| 671 |
+
device=h_skin.device,
|
| 672 |
+
dtype=h_skin.feats.dtype,
|
| 673 |
+
)
|
| 674 |
+
h_skin = self.skin_cross_groupped_net(h_skin, [joints_ctx])
|
| 675 |
+
for block in self.upsample_skin_net:
|
| 676 |
+
h_skin = block(h_skin)
|
| 677 |
+
h_skin = h_skin.type(x.dtype)
|
| 678 |
+
h_skin = self.out_layer_skin(h_skin)
|
| 679 |
+
|
| 680 |
+
# Mesh prediction
|
| 681 |
+
for block in self.upsample:
|
| 682 |
+
h = block(h)
|
| 683 |
+
h_middle = h.type(x.dtype)
|
| 684 |
+
h = self.out_layer(h_middle)
|
| 685 |
+
h = h.replace(torch.cat([h.feats, h_skin.feats], dim=-1))
|
| 686 |
+
meshes = self.to_representation(h)
|
| 687 |
+
|
| 688 |
+
self.skinweight_forward(meshes, skeletons, gt_joints=gt_joints, gt_parents=gt_parents)
|
| 689 |
+
return meshes, skeletons
|
| 690 |
+
|
| 691 |
+
def _joints_feats_list_to_sparse(
|
| 692 |
+
self,
|
| 693 |
+
joints_feats_list: List[torch.Tensor],
|
| 694 |
+
device: Optional[torch.device] = None,
|
| 695 |
+
dtype: Optional[torch.dtype] = None,
|
| 696 |
+
) -> sp.SparseTensor:
|
| 697 |
+
if device is None:
|
| 698 |
+
device = self.device
|
| 699 |
+
if dtype is None:
|
| 700 |
+
dtype = self.dtype
|
| 701 |
+
feats_per_batch: List[torch.Tensor] = []
|
| 702 |
+
for joints_feats in joints_feats_list:
|
| 703 |
+
joints_feats = joints_feats.to(device=device, dtype=dtype)
|
| 704 |
+
feats_per_batch.append(joints_feats)
|
| 705 |
+
feats = torch.cat(feats_per_batch, dim=0)
|
| 706 |
+
# Coords are [batch, x, y, z]. We encode token index into x and keep y/z = 0.
|
| 707 |
+
batch_indices: List[torch.Tensor] = []
|
| 708 |
+
x_indices: List[torch.Tensor] = []
|
| 709 |
+
for bi, joints_feats in enumerate(feats_per_batch):
|
| 710 |
+
ji = int(joints_feats.shape[0])
|
| 711 |
+
batch_indices.append(torch.full((ji,), bi, device=device, dtype=torch.int32))
|
| 712 |
+
x_indices.append(torch.arange(ji, device=device, dtype=torch.int32))
|
| 713 |
+
b = torch.cat(batch_indices, dim=0)
|
| 714 |
+
x = torch.cat(x_indices, dim=0)
|
| 715 |
+
yz = torch.zeros((x.shape[0], 2), device=device, dtype=torch.int32)
|
| 716 |
+
coords = torch.cat([b[:, None], x[:, None], yz], dim=1)
|
| 717 |
+
return sp.SparseTensor(feats=feats, coords=coords)
|
| 718 |
+
|
| 719 |
+
def _build_processed_joints_context(
|
| 720 |
+
self,
|
| 721 |
+
skeletons: List[Any],
|
| 722 |
+
skin_feats_joints_list: List[torch.Tensor],
|
| 723 |
+
device: torch.device,
|
| 724 |
+
dtype: torch.dtype,
|
| 725 |
+
) -> sp.SparseTensor:
|
| 726 |
+
processed: List[torch.Tensor] = []
|
| 727 |
+
raw_skin: List[torch.Tensor] = []
|
| 728 |
+
for rep_skl, skin_feats_joints in zip(skeletons, skin_feats_joints_list):
|
| 729 |
+
joints = rep_skl.joints_grouped
|
| 730 |
+
parents = rep_skl.parents_grouped
|
| 731 |
+
if joints is None or parents is None:
|
| 732 |
+
raise ValueError('Expected grouped joints/parents for skin_cross_from_groupped.')
|
| 733 |
+
joints = joints.to(device=device, dtype=dtype)
|
| 734 |
+
parents = parents.to(device=device)
|
| 735 |
+
skin_feats_joints = skin_feats_joints.to(device=device, dtype=dtype)
|
| 736 |
+
raw_skin.append(skin_feats_joints)
|
| 737 |
+
|
| 738 |
+
pe = self.joints_pos_embedder(joints).to(device=device, dtype=dtype)
|
| 739 |
+
|
| 740 |
+
# Parent skin features (root uses trainable parameter)
|
| 741 |
+
parent_idx = parents.to(torch.long)
|
| 742 |
+
valid = parent_idx >= 0
|
| 743 |
+
root_feat = self.root_parent_feat.to(device=device, dtype=dtype)
|
| 744 |
+
parent_feat_root = root_feat.unsqueeze(0).expand(skin_feats_joints.shape[0], -1)
|
| 745 |
+
parent_feat_gather = skin_feats_joints[parent_idx.clamp(min=0)]
|
| 746 |
+
parent_feat = torch.where(valid.unsqueeze(1), parent_feat_gather, parent_feat_root)
|
| 747 |
+
|
| 748 |
+
joint_in = torch.cat([skin_feats_joints, pe, parent_feat], dim=-1)
|
| 749 |
+
joint_h = self.joints_in_proj(joint_in)
|
| 750 |
+
processed.append(joint_h)
|
| 751 |
+
|
| 752 |
+
joints_ctx = self._joints_feats_list_to_sparse(processed, device=device, dtype=dtype)
|
| 753 |
+
for blk in self.joints_self_attn:
|
| 754 |
+
joints_ctx = blk(joints_ctx)
|
| 755 |
+
# Skip connection: concatenate original joint skin feats after self-attn.
|
| 756 |
+
joints_skip = self._joints_feats_list_to_sparse(raw_skin, device=device, dtype=dtype)
|
| 757 |
+
joints_ctx = joints_ctx.replace(torch.cat([joints_ctx.feats, joints_skip.feats], dim=-1))
|
| 758 |
+
return joints_ctx
|
| 759 |
+
|
| 760 |
+
def skeleton_grouping(self, reps_skl, gt_joints=None, gt_parents=None, skin_feats_skl_list=None, return_skin_pred_only=False):
|
| 761 |
+
skin_feats_joints_list = []
|
| 762 |
+
for i, rep_skl in zip(range(len(reps_skl)), reps_skl):
|
| 763 |
+
if gt_joints is not None:
|
| 764 |
+
joints_grouped = gt_joints[i]
|
| 765 |
+
parents_grouped = gt_parents[i]
|
| 766 |
+
elif rep_skl.joints_grouped is None:
|
| 767 |
+
with torch.no_grad():
|
| 768 |
+
joints_grouped, parents_grouped = self.grouping_func(joints=rep_skl.joints, parents=rep_skl.parents, joints_conf=rep_skl.conf_j, parents_conf=rep_skl.conf_p)
|
| 769 |
+
else:
|
| 770 |
+
joints_grouped = rep_skl.joints_grouped
|
| 771 |
+
parents_grouped = rep_skl.parents_grouped
|
| 772 |
+
|
| 773 |
+
if not return_skin_pred_only:
|
| 774 |
+
rep_skl.joints_grouped = joints_grouped
|
| 775 |
+
rep_skl.parents_grouped = parents_grouped
|
| 776 |
+
|
| 777 |
+
# Calculate NN indices for joints
|
| 778 |
+
positions_skl = rep_skl.positions
|
| 779 |
+
_, joints_nn_idx, _ = knn_points(positions_skl[None], joints_grouped[None].detach(), K=1, norm=2, return_nn=False)
|
| 780 |
+
joints_nn_idx = joints_nn_idx[0, :, 0]
|
| 781 |
+
skin_feats_skl = rep_skl.skin_feats if skin_feats_skl_list is None else skin_feats_skl_list[i]
|
| 782 |
+
|
| 783 |
+
# Average the predicted joint features
|
| 784 |
+
conf_skin = torch.sigmoid(rep_skl.conf_skin) if rep_skl.conf_skin is not None else torch.ones_like(skin_feats_skl[:, :1])
|
| 785 |
+
|
| 786 |
+
skin_feats_joints = torch.zeros([joints_grouped.shape[0], skin_feats_skl.shape[-1]], device=self.device, dtype=skin_feats_skl.dtype)
|
| 787 |
+
skin_feats_square_joints = skin_feats_joints.clone()
|
| 788 |
+
skin_conf_joints = torch.zeros([joints_grouped.shape[0], 1], device=self.device, dtype=skin_feats_skl.dtype)
|
| 789 |
+
|
| 790 |
+
skin_feats_joints.scatter_add_(0, joints_nn_idx[:, None].expand(-1, skin_feats_skl.shape[-1]), skin_feats_skl * conf_skin)
|
| 791 |
+
skin_feats_square_joints.scatter_add_(0, joints_nn_idx[:, None].expand(-1, skin_feats_skl.shape[-1]), skin_feats_skl.square() * conf_skin)
|
| 792 |
+
skin_conf_joints.scatter_add_(0, joints_nn_idx[:, None], conf_skin)
|
| 793 |
+
|
| 794 |
+
skin_feats_joints = skin_feats_joints / skin_conf_joints.clamp(min=1e-6)
|
| 795 |
+
skin_feats_square_joints = skin_feats_square_joints / skin_conf_joints.clamp(min=1e-6)
|
| 796 |
+
skin_feats_joints_var = skin_feats_square_joints - skin_feats_joints.square()
|
| 797 |
+
skin_feats_joints_var_loss = skin_feats_joints_var.mean()
|
| 798 |
+
|
| 799 |
+
if not return_skin_pred_only:
|
| 800 |
+
rep_skl.skin_feats_joints_var_loss = skin_feats_joints_var_loss
|
| 801 |
+
rep_skl.skin_feats_joints = skin_feats_joints
|
| 802 |
+
skin_feats_joints_list.append(skin_feats_joints)
|
| 803 |
+
return skin_feats_joints_list
|
| 804 |
+
|
| 805 |
+
def skinweight_forward(self, reps, reps_skl, gt_joints=None, gt_parents=None, return_skin_pred_only=False, skin_feats_verts_list=None, skin_feats_skl_list=None, *args, **kwargs):
|
| 806 |
+
if return_skin_pred_only:
|
| 807 |
+
skin_preds = []
|
| 808 |
+
if reps_skl[0].parents_grouped is None or return_skin_pred_only:
|
| 809 |
+
skin_feats_joints_list = self.skeleton_grouping(reps_skl, gt_joints=gt_joints, gt_parents=gt_parents, skin_feats_skl_list=skin_feats_skl_list, return_skin_pred_only=return_skin_pred_only)
|
| 810 |
+
else:
|
| 811 |
+
skin_feats_joints_list = [rep_skl.skin_feats_joints for rep_skl in reps_skl]
|
| 812 |
+
for i, rep, rep_skl in zip(range(len(reps)), reps, reps_skl):
|
| 813 |
+
# Joint skinning features
|
| 814 |
+
skin_feats_joints = skin_feats_joints_list[i]
|
| 815 |
+
# Vertex skinning features
|
| 816 |
+
skin_feats_verts = rep.vertex_skin_feats if skin_feats_verts_list is None else skin_feats_verts_list[i]
|
| 817 |
+
# Predict skin weights
|
| 818 |
+
parents_grouped = rep_skl.parents_grouped
|
| 819 |
+
skin_pred = self.skin_decoder(skin_feats_verts[None], skin_feats_joints[None], parents_grouped[None])
|
| 820 |
+
skin_pred = skin_pred[0]
|
| 821 |
+
if return_skin_pred_only:
|
| 822 |
+
skin_preds.append(skin_pred)
|
| 823 |
+
else:
|
| 824 |
+
reps_skl[i].skin_pred = skin_pred
|
| 825 |
+
if return_skin_pred_only:
|
| 826 |
+
return skin_preds
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
class AniGenElasticSLatMeshDecoder(SparseTransformerElasticMixin, AniGenSLatMeshDecoder):
|
| 830 |
+
"""
|
| 831 |
+
Slat VAE Mesh decoder with elastic memory management.
|
| 832 |
+
Used for training with low VRAM.
|
| 833 |
+
"""
|
| 834 |
+
pass
|
anigen/models/structured_latent_vae/anigen_encoder.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from ...modules import sparse as sp
|
| 6 |
+
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
|
| 7 |
+
from .anigen_base import AniGenSparseTransformerBase, FreqPositionalEmbedder
|
| 8 |
+
from pytorch3d.ops import knn_points
|
| 9 |
+
from .skin_models import SkinEncoder
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def block_attn_config(self):
|
| 13 |
+
"""
|
| 14 |
+
Return the attention configuration of the model.
|
| 15 |
+
"""
|
| 16 |
+
for i in range(self.num_blocks):
|
| 17 |
+
if self.attn_mode == "shift_window":
|
| 18 |
+
yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
|
| 19 |
+
elif self.attn_mode == "shift_sequence":
|
| 20 |
+
yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
|
| 21 |
+
elif self.attn_mode == "shift_order":
|
| 22 |
+
yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
|
| 23 |
+
elif self.attn_mode == "full":
|
| 24 |
+
yield "full", None, None, None, None
|
| 25 |
+
elif self.attn_mode == "swin":
|
| 26 |
+
yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class FeedForwardNet(nn.Module):
|
| 30 |
+
def __init__(self, channels: int, channels_out: int=None, mlp_ratio: float = 4.0):
|
| 31 |
+
super().__init__()
|
| 32 |
+
channels_out = channels if channels_out is None else channels_out
|
| 33 |
+
self.mlp = nn.Sequential(
|
| 34 |
+
nn.Linear(channels, int(channels * mlp_ratio)),
|
| 35 |
+
nn.GELU(approximate="tanh"),
|
| 36 |
+
nn.Linear(int(channels * mlp_ratio), channels_out),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 40 |
+
return self.mlp(x)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class AniGenSLatEncoder(AniGenSparseTransformerBase):
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
resolution: int,
|
| 47 |
+
in_channels: int,
|
| 48 |
+
|
| 49 |
+
model_channels: int,
|
| 50 |
+
model_channels_skl: int,
|
| 51 |
+
model_channels_skin: int,
|
| 52 |
+
|
| 53 |
+
latent_channels: int,
|
| 54 |
+
latent_channels_skl: int,
|
| 55 |
+
latent_channels_vertskin: int,
|
| 56 |
+
|
| 57 |
+
num_blocks: int,
|
| 58 |
+
num_heads: Optional[int] = None,
|
| 59 |
+
num_head_channels: Optional[int] = 64,
|
| 60 |
+
|
| 61 |
+
num_heads_skl: int = 32,
|
| 62 |
+
num_heads_skin: int = 32,
|
| 63 |
+
|
| 64 |
+
skl_pos_embed_freq: int = 10,
|
| 65 |
+
skin_encoder_config: Optional[Dict[str, Any]] = {},
|
| 66 |
+
encode_upsampled_skin_feat: bool = True,
|
| 67 |
+
skin_ae_name: Optional[str] = "SkinAE",
|
| 68 |
+
|
| 69 |
+
mlp_ratio: float = 4,
|
| 70 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
|
| 71 |
+
attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
|
| 72 |
+
window_size: int = 8,
|
| 73 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 74 |
+
use_fp16: bool = False,
|
| 75 |
+
use_checkpoint: bool = False,
|
| 76 |
+
qk_rms_norm: bool = False,
|
| 77 |
+
|
| 78 |
+
use_pretrain_branch: bool = True,
|
| 79 |
+
freeze_pretrain_branch: bool = True,
|
| 80 |
+
modules_to_freeze: Optional[List[str]] = ["input_layer", "blocks", "out_layer", "skin_encoder"],
|
| 81 |
+
|
| 82 |
+
skin_cross_from_geo: bool = True,
|
| 83 |
+
skl_cross_from_geo: bool = True,
|
| 84 |
+
skin_skl_cross: bool = True,
|
| 85 |
+
|
| 86 |
+
latent_denoising: bool = True,
|
| 87 |
+
normalize_z: bool = True,
|
| 88 |
+
normalize_scale: float = 1.0,
|
| 89 |
+
|
| 90 |
+
jp_residual_fields: bool = False,
|
| 91 |
+
jp_hyper_continuous: bool = False,
|
| 92 |
+
):
|
| 93 |
+
self.use_pretrain_branch = use_pretrain_branch
|
| 94 |
+
self.freeze_pretrain_branch = freeze_pretrain_branch
|
| 95 |
+
self.skl_pos_embed_freq = skl_pos_embed_freq
|
| 96 |
+
self.latent_denoising = latent_denoising
|
| 97 |
+
self.normalize_latent = normalize_z and latent_denoising
|
| 98 |
+
self.normalize_scale = normalize_scale
|
| 99 |
+
self.jp_residual_fields = jp_residual_fields
|
| 100 |
+
self.jp_hyper_continuous = jp_hyper_continuous
|
| 101 |
+
|
| 102 |
+
super().__init__(
|
| 103 |
+
in_channels=in_channels,
|
| 104 |
+
in_channels_skl=model_channels_skl,
|
| 105 |
+
in_channels_skin=model_channels_skin,
|
| 106 |
+
model_channels=model_channels,
|
| 107 |
+
model_channels_skl=model_channels_skl,
|
| 108 |
+
model_channels_skin=model_channels_skin,
|
| 109 |
+
num_blocks=num_blocks,
|
| 110 |
+
num_heads=num_heads,
|
| 111 |
+
num_heads_skl=num_heads_skl,
|
| 112 |
+
num_heads_skin=num_heads_skin,
|
| 113 |
+
num_head_channels=num_head_channels,
|
| 114 |
+
mlp_ratio=mlp_ratio,
|
| 115 |
+
attn_mode=attn_mode,
|
| 116 |
+
attn_mode_cross=attn_mode_cross,
|
| 117 |
+
window_size=window_size,
|
| 118 |
+
pe_mode=pe_mode,
|
| 119 |
+
use_fp16=use_fp16,
|
| 120 |
+
use_checkpoint=use_checkpoint,
|
| 121 |
+
qk_rms_norm=qk_rms_norm,
|
| 122 |
+
skin_cross_from_geo=skin_cross_from_geo,
|
| 123 |
+
skl_cross_from_geo=skl_cross_from_geo,
|
| 124 |
+
skin_skl_cross=skin_skl_cross,
|
| 125 |
+
)
|
| 126 |
+
self.pretrain_class_name = ["AniGenElasticSLatEncoder", skin_ae_name]
|
| 127 |
+
self.pretrain_ckpt_filter_prefix = {skin_ae_name: "skin_encoder"}
|
| 128 |
+
self.resolution = resolution
|
| 129 |
+
|
| 130 |
+
self.latent_channels = latent_channels
|
| 131 |
+
self.latent_channels_skl = latent_channels_skl
|
| 132 |
+
self.latent_channels_vertskin = latent_channels_vertskin
|
| 133 |
+
|
| 134 |
+
skin_encoder_config['use_fp16'] = use_fp16
|
| 135 |
+
self.skin_encoder = SkinEncoder(**skin_encoder_config)
|
| 136 |
+
self.encode_upsampled_skin_feat = encode_upsampled_skin_feat
|
| 137 |
+
self.in_layer_skin = FeedForwardNet(channels=self.skin_encoder.skin_feat_channels * (8 if encode_upsampled_skin_feat else 1), channels_out=model_channels_skin)
|
| 138 |
+
|
| 139 |
+
self.pos_embedder_fourier = FreqPositionalEmbedder(in_dim=4 if self.jp_hyper_continuous else 3, max_freq_log2=self.skl_pos_embed_freq, num_freqs=self.skl_pos_embed_freq, include_input=True)
|
| 140 |
+
self.root_embedding = nn.Parameter(torch.zeros(1, self.pos_embedder_fourier.out_dim))
|
| 141 |
+
|
| 142 |
+
# Channel Balance
|
| 143 |
+
self.in_layer_jp_skl = FeedForwardNet(channels=2 * self.pos_embedder_fourier.out_dim, channels_out=model_channels_skl//4)
|
| 144 |
+
self.in_layer_skin_skl = FeedForwardNet(channels=self.skin_encoder.skin_feat_channels, channels_out=model_channels_skl-(model_channels_skl//4))
|
| 145 |
+
|
| 146 |
+
self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
|
| 147 |
+
if self.latent_denoising:
|
| 148 |
+
self.out_layer_skl = sp.SparseLinear(model_channels_skl, latent_channels_skl)
|
| 149 |
+
self.out_layer_vertskin = sp.SparseLinear(model_channels_skin, latent_channels_vertskin)
|
| 150 |
+
else:
|
| 151 |
+
self.out_layer_skl = sp.SparseLinear(model_channels_skl, 2 * latent_channels_skl)
|
| 152 |
+
self.out_layer_vertskin = sp.SparseLinear(model_channels_skin, 2 * latent_channels_vertskin)
|
| 153 |
+
|
| 154 |
+
self.initialize_weights()
|
| 155 |
+
if use_fp16:
|
| 156 |
+
self.convert_to_fp16()
|
| 157 |
+
else:
|
| 158 |
+
self.convert_to_fp32()
|
| 159 |
+
|
| 160 |
+
if 'all' in modules_to_freeze:
|
| 161 |
+
modules_to_freeze = list(set([k.split('.')[0] for k in self.state_dict().keys()]))
|
| 162 |
+
print(f"\033[93mFreezing all modules: {modules_to_freeze}\033[0m")
|
| 163 |
+
if self.use_pretrain_branch and self.freeze_pretrain_branch:
|
| 164 |
+
for module in modules_to_freeze:
|
| 165 |
+
if hasattr(self, module):
|
| 166 |
+
mod = getattr(self, module)
|
| 167 |
+
if isinstance(mod, nn.ModuleList):
|
| 168 |
+
for m in mod:
|
| 169 |
+
for name, param in m.named_parameters():
|
| 170 |
+
if 'lora' not in name:
|
| 171 |
+
param.requires_grad = False
|
| 172 |
+
elif isinstance(mod, nn.Module):
|
| 173 |
+
for name, param in mod.named_parameters():
|
| 174 |
+
if 'lora' not in name:
|
| 175 |
+
param.requires_grad = False
|
| 176 |
+
elif isinstance(mod, torch.Tensor):
|
| 177 |
+
if mod.requires_grad:
|
| 178 |
+
mod.requires_grad = False
|
| 179 |
+
|
| 180 |
+
def initialize_weights(self) -> None:
|
| 181 |
+
super().initialize_weights()
|
| 182 |
+
# Zero-out output layers:
|
| 183 |
+
nn.init.constant_(self.out_layer.weight, 0)
|
| 184 |
+
nn.init.constant_(self.out_layer.bias, 0)
|
| 185 |
+
|
| 186 |
+
def skeleton_embedding(self, x, x_skl, joints_list, parents_list, skin_list, gt_meshes, bvh_list=None):
|
| 187 |
+
res = self.resolution
|
| 188 |
+
feats_new = []
|
| 189 |
+
feats_skl_new = []
|
| 190 |
+
coords_new = []
|
| 191 |
+
coords_skl_new = []
|
| 192 |
+
|
| 193 |
+
joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
|
| 194 |
+
joints_pos_list = []
|
| 195 |
+
|
| 196 |
+
for i in range(len(joints_list)):
|
| 197 |
+
parent_idx = parents_list[i].clone()
|
| 198 |
+
|
| 199 |
+
coords_new.append(x[i].coords)
|
| 200 |
+
coords_skl_new.append(x_skl[i].coords)
|
| 201 |
+
coords_new[-1][:, 0] = i
|
| 202 |
+
coords_skl_new[-1][:, 0] = i
|
| 203 |
+
|
| 204 |
+
v_pos = (x[i].coords[:, 1:4] + 0.5) / res - 0.5
|
| 205 |
+
v_pos_skl = (x_skl[i].coords[:, 1:4] + 0.5) / res - 0.5
|
| 206 |
+
dist_nn_12, joints_nn_idx, _ = knn_points(v_pos_skl[None], joints_list[i][None], K=2, norm=2, return_nn=False)
|
| 207 |
+
joints_nn_idx = joints_nn_idx[0, :, 0]
|
| 208 |
+
|
| 209 |
+
# Skeleton positional embedding
|
| 210 |
+
joints_pos = joints_list[i][joints_nn_idx] - (v_pos_skl if self.jp_residual_fields else 0)
|
| 211 |
+
parents_pos = joints_list[i][parent_idx[joints_nn_idx]] - (v_pos_skl if self.jp_residual_fields else 0)
|
| 212 |
+
if self.jp_hyper_continuous:
|
| 213 |
+
factor = (1 - (dist_nn_12[0, :, 0:1] / (dist_nn_12[0, :, 1:2] + 1e-8)).clamp(max=1.0))
|
| 214 |
+
joints_pos = torch.cat([joints_pos, factor], dim=-1)
|
| 215 |
+
parents_pos = torch.cat([parents_pos, factor], dim=-1)
|
| 216 |
+
joints_pos_embed = self.pos_embedder_fourier(joints_pos)
|
| 217 |
+
parents_pos_embed = self.pos_embedder_fourier(parents_pos)
|
| 218 |
+
parents_pos_embed = torch.where(parent_idx[joints_nn_idx][:, None] == -1, self.root_embedding.expand_as(parents_pos_embed), parents_pos_embed)
|
| 219 |
+
jp_pos_embed_nn = torch.cat([joints_pos_embed, parents_pos_embed], dim=-1)
|
| 220 |
+
jp_pos_embed_nn = self.in_layer_jp_skl(jp_pos_embed_nn)
|
| 221 |
+
|
| 222 |
+
# Skeleton skin embedding
|
| 223 |
+
j_skin_embed_nn = joint_skin_embeds[i][joints_nn_idx]
|
| 224 |
+
j_skin_embed_nn = self.in_layer_skin_skl(j_skin_embed_nn)
|
| 225 |
+
|
| 226 |
+
# Concatenate
|
| 227 |
+
jp_skl_embed = torch.cat([jp_pos_embed_nn, j_skin_embed_nn], dim=-1)
|
| 228 |
+
feats_skl_new.append(jp_skl_embed)
|
| 229 |
+
|
| 230 |
+
if self.encode_upsampled_skin_feat:
|
| 231 |
+
# Create 8 sub-voxel points
|
| 232 |
+
offsets = torch.tensor([
|
| 233 |
+
[-1, -1, -1], [-1, -1, 1], [-1, 1, -1], [-1, 1, 1],
|
| 234 |
+
[1, -1, -1], [1, -1, 1], [1, 1, -1], [1, 1, 1]
|
| 235 |
+
], device=v_pos.device, dtype=v_pos.dtype) * (0.25 / res)
|
| 236 |
+
query_pos = v_pos.unsqueeze(1) + offsets.unsqueeze(0) # (N, 8, 3)
|
| 237 |
+
query_pos_flat = query_pos.view(-1, 3)
|
| 238 |
+
else:
|
| 239 |
+
query_pos_flat = v_pos
|
| 240 |
+
|
| 241 |
+
if bvh_list is not None:
|
| 242 |
+
bvh = bvh_list[i].to(v_pos.device)
|
| 243 |
+
_, face_id, uvw = bvh.unsigned_distance(query_pos_flat, return_uvw=True)
|
| 244 |
+
uvw = uvw.clamp(min=0.0)
|
| 245 |
+
uvw_sum = uvw.sum(dim=-1, keepdim=True).clamp_min(1e-6)
|
| 246 |
+
uvw = uvw / uvw_sum
|
| 247 |
+
face_id = gt_meshes[i]['faces'][face_id]
|
| 248 |
+
voxel_skin_embeds = (vert_skin_embeds[i][face_id] * uvw[..., None]).sum(1)
|
| 249 |
+
else:
|
| 250 |
+
gt_mesh_verts = gt_meshes[i]['vertices']
|
| 251 |
+
_, mesh_nn_idx, _ = knn_points(query_pos_flat[None], gt_mesh_verts[None], K=1, norm=2, return_nn=False)
|
| 252 |
+
mesh_nn_idx = mesh_nn_idx[0, :, 0]
|
| 253 |
+
voxel_skin_embeds = vert_skin_embeds[i][mesh_nn_idx]
|
| 254 |
+
|
| 255 |
+
voxel_skin_embeds = voxel_skin_embeds.view(v_pos.shape[0], -1)
|
| 256 |
+
voxel_skin_embeds = self.in_layer_skin(voxel_skin_embeds)
|
| 257 |
+
feats_new.append(voxel_skin_embeds)
|
| 258 |
+
joints_pos_list.append(joints_pos)
|
| 259 |
+
|
| 260 |
+
x_new = sp.SparseTensor(coords=torch.cat(coords_new, dim=0), feats=torch.cat(feats_new, dim=0))
|
| 261 |
+
x_skl_new = sp.SparseTensor(coords=torch.cat(coords_skl_new, dim=0), feats=torch.cat(feats_skl_new, dim=0))
|
| 262 |
+
|
| 263 |
+
return x_new, x_skl_new, joint_skin_embeds, vert_skin_embeds, joints_pos_list
|
| 264 |
+
|
| 265 |
+
def encode_sample(self, x: sp.SparseTensor, out_layer: sp.SparseLinear, sample_posterior: bool = True, latent_denoising: bool = False):
|
| 266 |
+
x = x.type(torch.float32)
|
| 267 |
+
x = x.replace(F.layer_norm(x.feats, x.feats.shape[-1:]))
|
| 268 |
+
x = out_layer(x)
|
| 269 |
+
if latent_denoising:
|
| 270 |
+
if self.normalize_latent:
|
| 271 |
+
x = x.replace(nn.functional.normalize(x.feats, dim=-1) * self.normalize_scale)
|
| 272 |
+
mean, logvar = x.feats, torch.zeros_like(x.feats)
|
| 273 |
+
else:
|
| 274 |
+
mean, logvar = x.feats.chunk(2, dim=-1)
|
| 275 |
+
if sample_posterior and not latent_denoising:
|
| 276 |
+
std = torch.exp(0.5 * logvar)
|
| 277 |
+
z = mean + std * torch.randn_like(std)
|
| 278 |
+
else:
|
| 279 |
+
z = mean
|
| 280 |
+
z = x.replace(z)
|
| 281 |
+
if latent_denoising:
|
| 282 |
+
mean = mean.detach()
|
| 283 |
+
return z, mean, logvar
|
| 284 |
+
|
| 285 |
+
def forward(self, x: sp.SparseTensor, x_skl: sp.SparseTensor, sample_posterior=True, return_raw=False, return_skin_encoded=False, **kwargs):
|
| 286 |
+
x_skin, x_skl, joint_skin_embeds, vert_skin_embeds, joints_pos = self.skeleton_embedding(x, x_skl, kwargs.get('gt_joints'), kwargs.get('gt_parents'), kwargs.get('gt_skin'), kwargs.get('gt_mesh'), kwargs.get('bvh_list', None))
|
| 287 |
+
h, h_skl, h_skin = super().forward(x, x_skl, x_skin)
|
| 288 |
+
|
| 289 |
+
z, mean, logvar = self.encode_sample(h, self.out_layer, sample_posterior, latent_denoising=False)
|
| 290 |
+
z_skl, mean_skl, logvar_skl = self.encode_sample(h_skl, self.out_layer_skl, sample_posterior, latent_denoising=self.latent_denoising)
|
| 291 |
+
z_skin, mean_skin, logvar_skin = self.encode_sample(h_skin, self.out_layer_vertskin, sample_posterior, latent_denoising=self.latent_denoising)
|
| 292 |
+
|
| 293 |
+
z = z.replace(torch.cat([z.feats, z_skin.feats], dim=-1))
|
| 294 |
+
mean, logvar = torch.cat([mean, mean_skin], dim=-1), torch.cat([logvar, logvar_skin], dim=-1)
|
| 295 |
+
|
| 296 |
+
if not return_skin_encoded:
|
| 297 |
+
# Ordinary return without skin encoded features
|
| 298 |
+
if return_raw:
|
| 299 |
+
return z, mean, logvar, z_skl, mean_skl, logvar_skl, joint_skin_embeds, vert_skin_embeds, joints_pos
|
| 300 |
+
else:
|
| 301 |
+
return z, z_skl, joint_skin_embeds, vert_skin_embeds, joints_pos
|
| 302 |
+
else:
|
| 303 |
+
# Return skin encoded features as well for checking
|
| 304 |
+
if return_raw:
|
| 305 |
+
return z, mean, logvar, z_skl, mean_skl, logvar_skl, joint_skin_embeds, vert_skin_embeds, joints_pos, x_skin, x_skl
|
| 306 |
+
else:
|
| 307 |
+
return z, z_skl, joint_skin_embeds, vert_skin_embeds, joints_pos, x_skin, x_skl
|
| 308 |
+
|
| 309 |
+
def encode_skin(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]=None):
|
| 310 |
+
joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
|
| 311 |
+
return joint_skin_embeds, vert_skin_embeds
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class AniGenElasticSLatEncoder(SparseTransformerElasticMixin, AniGenSLatEncoder):
|
| 315 |
+
"""
|
| 316 |
+
SLat VAE encoder with elastic memory management.
|
| 317 |
+
Used for training with low VRAM.
|
| 318 |
+
"""
|
anigen/models/structured_latent_vae/base.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
|
| 5 |
+
from ...modules import sparse as sp
|
| 6 |
+
from ...modules.transformer import AbsolutePositionEmbedder
|
| 7 |
+
from ...modules.sparse.transformer import SparseTransformerBlock
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def block_attn_config(self):
|
| 11 |
+
"""
|
| 12 |
+
Return the attention configuration of the model.
|
| 13 |
+
"""
|
| 14 |
+
for i in range(self.num_blocks):
|
| 15 |
+
if self.attn_mode == "shift_window":
|
| 16 |
+
yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
|
| 17 |
+
elif self.attn_mode == "shift_sequence":
|
| 18 |
+
yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
|
| 19 |
+
elif self.attn_mode == "shift_order":
|
| 20 |
+
yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
|
| 21 |
+
elif self.attn_mode == "full":
|
| 22 |
+
yield "full", None, None, None, None
|
| 23 |
+
elif self.attn_mode == "swin":
|
| 24 |
+
yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SparseTransformerBase(nn.Module):
|
| 28 |
+
"""
|
| 29 |
+
Sparse Transformer without output layers.
|
| 30 |
+
Serve as the base class for encoder and decoder.
|
| 31 |
+
"""
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
in_channels: int,
|
| 35 |
+
model_channels: int,
|
| 36 |
+
num_blocks: int,
|
| 37 |
+
num_heads: Optional[int] = None,
|
| 38 |
+
num_head_channels: Optional[int] = 64,
|
| 39 |
+
mlp_ratio: float = 4.0,
|
| 40 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
| 41 |
+
window_size: Optional[int] = None,
|
| 42 |
+
pe_mode: Literal["ape", "rope"] = "ape",
|
| 43 |
+
use_fp16: bool = False,
|
| 44 |
+
use_checkpoint: bool = False,
|
| 45 |
+
qk_rms_norm: bool = False,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.in_channels = in_channels
|
| 49 |
+
self.model_channels = model_channels
|
| 50 |
+
self.num_blocks = num_blocks
|
| 51 |
+
self.window_size = window_size
|
| 52 |
+
self.num_heads = num_heads or model_channels // num_head_channels
|
| 53 |
+
self.mlp_ratio = mlp_ratio
|
| 54 |
+
self.attn_mode = attn_mode
|
| 55 |
+
self.pe_mode = pe_mode
|
| 56 |
+
self.use_fp16 = use_fp16
|
| 57 |
+
self.use_checkpoint = use_checkpoint
|
| 58 |
+
self.qk_rms_norm = qk_rms_norm
|
| 59 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 60 |
+
|
| 61 |
+
if pe_mode == "ape":
|
| 62 |
+
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
| 63 |
+
|
| 64 |
+
self.input_layer = sp.SparseLinear(in_channels, model_channels)
|
| 65 |
+
self.blocks = nn.ModuleList([
|
| 66 |
+
SparseTransformerBlock(
|
| 67 |
+
model_channels,
|
| 68 |
+
num_heads=self.num_heads,
|
| 69 |
+
mlp_ratio=self.mlp_ratio,
|
| 70 |
+
attn_mode=attn_mode,
|
| 71 |
+
window_size=window_size,
|
| 72 |
+
shift_sequence=shift_sequence,
|
| 73 |
+
shift_window=shift_window,
|
| 74 |
+
serialize_mode=serialize_mode,
|
| 75 |
+
use_checkpoint=self.use_checkpoint,
|
| 76 |
+
use_rope=(pe_mode == "rope"),
|
| 77 |
+
qk_rms_norm=self.qk_rms_norm,
|
| 78 |
+
)
|
| 79 |
+
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
|
| 80 |
+
])
|
| 81 |
+
|
| 82 |
+
@property
|
| 83 |
+
def device(self) -> torch.device:
|
| 84 |
+
"""
|
| 85 |
+
Return the device of the model.
|
| 86 |
+
"""
|
| 87 |
+
return next(self.parameters()).device
|
| 88 |
+
|
| 89 |
+
def convert_to_fp16(self) -> None:
|
| 90 |
+
"""
|
| 91 |
+
Convert the torso of the model to float16.
|
| 92 |
+
"""
|
| 93 |
+
self.blocks.apply(convert_module_to_f16)
|
| 94 |
+
|
| 95 |
+
def convert_to_fp32(self) -> None:
|
| 96 |
+
"""
|
| 97 |
+
Convert the torso of the model to float32.
|
| 98 |
+
"""
|
| 99 |
+
self.blocks.apply(convert_module_to_f32)
|
| 100 |
+
|
| 101 |
+
def initialize_weights(self) -> None:
|
| 102 |
+
# Initialize transformer layers:
|
| 103 |
+
def _basic_init(module):
|
| 104 |
+
if isinstance(module, nn.Linear):
|
| 105 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 106 |
+
if module.bias is not None:
|
| 107 |
+
nn.init.constant_(module.bias, 0)
|
| 108 |
+
self.apply(_basic_init)
|
| 109 |
+
|
| 110 |
+
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
| 111 |
+
h = self.input_layer(x)
|
| 112 |
+
if self.pe_mode == "ape":
|
| 113 |
+
h = h + self.pos_embedder(x.coords[:, 1:])
|
| 114 |
+
h = h.type(self.dtype)
|
| 115 |
+
for block in self.blocks:
|
| 116 |
+
h = block(h)
|
| 117 |
+
return h
|
anigen/models/structured_latent_vae/skin_models.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import *
|
| 4 |
+
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
|
| 5 |
+
from ...modules.transformer import TransformerBlock, FeedForwardNet
|
| 6 |
+
from .anigen_base import FreqPositionalEmbedder, TransformerCrossBlock
|
| 7 |
+
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Embedder(nn.Module):
|
| 11 |
+
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int=None, depth: int = 4, mlp_ratio: float = 4.0, jp_embed_attn: bool = True):
|
| 12 |
+
super().__init__()
|
| 13 |
+
hidden_dim = out_dim if hidden_dim is None else hidden_dim
|
| 14 |
+
self.jp_embed_attn = jp_embed_attn
|
| 15 |
+
self.in_layer = FeedForwardNet(channels=in_dim, out_channels=hidden_dim, mlp_ratio=mlp_ratio)
|
| 16 |
+
if self.jp_embed_attn:
|
| 17 |
+
self.blocks = nn.ModuleList([TransformerBlock(hidden_dim, num_heads=8, attn_mode='full') for _ in range(depth)])
|
| 18 |
+
for block in self.blocks:
|
| 19 |
+
block.to(torch.float16)
|
| 20 |
+
else:
|
| 21 |
+
self.blocks = nn.ModuleList([FeedForwardNet(channels=hidden_dim, out_channels=hidden_dim, mlp_ratio=mlp_ratio) for _ in range(depth)])
|
| 22 |
+
self.out_layer = nn.Linear(hidden_dim, out_dim)
|
| 23 |
+
|
| 24 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 25 |
+
x = self.in_layer(x)
|
| 26 |
+
h = x
|
| 27 |
+
for block in self.blocks:
|
| 28 |
+
h = block(h[None].type(torch.float16))[0] if self.jp_embed_attn else block(h) + x
|
| 29 |
+
h = self.out_layer(h.type(x.dtype))
|
| 30 |
+
return h
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class SkinEncoder(nn.Module):
|
| 34 |
+
def __init__(self, skin_feat_channels: int = 8, skl_pos_embed_freq: int = 10, jp_embedder_config: Optional[Dict[str, Any]] = {}, jp_embed_dim: int = 128, relative_pe=True, vert_feat_is_linear=True, normalize_feat=True, **kwargs):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.skin_feat_channels = skin_feat_channels
|
| 37 |
+
self.skl_pos_embed_freq = skl_pos_embed_freq
|
| 38 |
+
self.jp_embedder_config = jp_embedder_config
|
| 39 |
+
|
| 40 |
+
self.pos_embedder_fourier = FreqPositionalEmbedder(in_dim=3, max_freq_log2=self.skl_pos_embed_freq, num_freqs=self.skl_pos_embed_freq, include_input=True)
|
| 41 |
+
self.pos_embedder_linear = nn.Linear(self.pos_embedder_fourier.out_dim, jp_embed_dim)
|
| 42 |
+
self.root_embedding = nn.Parameter(torch.zeros(1, jp_embed_dim))
|
| 43 |
+
self.joint_embedder = Embedder(in_dim=2 * jp_embed_dim, out_dim=jp_embed_dim, **self.jp_embedder_config)
|
| 44 |
+
self.out_layer_vert = FeedForwardNet(channels=jp_embed_dim, out_channels=skin_feat_channels)
|
| 45 |
+
self.out_layer_joint = FeedForwardNet(channels=jp_embed_dim, out_channels=skin_feat_channels)
|
| 46 |
+
self.relative_pe = relative_pe
|
| 47 |
+
self.vert_feat_is_linear = vert_feat_is_linear
|
| 48 |
+
self.normalize_feat = normalize_feat
|
| 49 |
+
|
| 50 |
+
def forward(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]=None):
|
| 51 |
+
vert_skin_embeds = [] if skin_list is not None else None
|
| 52 |
+
joint_skin_embeds = []
|
| 53 |
+
for i in range(len(joints_list)):
|
| 54 |
+
parent_idx = parents_list[i].clone()
|
| 55 |
+
joints = joints_list[i]
|
| 56 |
+
if self.relative_pe:
|
| 57 |
+
joints = joints - torch.cat([joints, joints[:1]])[parent_idx]
|
| 58 |
+
joints_pos_embed = self.pos_embedder_linear(self.pos_embedder_fourier(joints))
|
| 59 |
+
joints_pos_embed = torch.cat([joints_pos_embed, self.root_embedding], dim=0)
|
| 60 |
+
parents_pos_embed = joints_pos_embed[parent_idx]
|
| 61 |
+
jp_pos_embed = torch.cat([joints_pos_embed[:-1], parents_pos_embed], dim=-1)
|
| 62 |
+
joints_embed = self.joint_embedder(jp_pos_embed)
|
| 63 |
+
if self.normalize_feat:
|
| 64 |
+
joints_embed = torch.nn.functional.normalize(joints_embed, dim=-1)
|
| 65 |
+
if skin_list is not None:
|
| 66 |
+
vert_skin = skin_list[i]
|
| 67 |
+
if self.vert_feat_is_linear:
|
| 68 |
+
joints_embed_for_vert = self.out_layer_vert(joints_embed)
|
| 69 |
+
vert_skin_embed = vert_skin @ joints_embed_for_vert
|
| 70 |
+
else:
|
| 71 |
+
vert_skin_embed = vert_skin @ joints_embed
|
| 72 |
+
vert_skin_embed = self.out_layer_vert(vert_skin_embed)
|
| 73 |
+
vert_skin_embeds.append(vert_skin_embed)
|
| 74 |
+
joints_embed = self.out_layer_joint(joints_embed)
|
| 75 |
+
joint_skin_embeds.append(joints_embed)
|
| 76 |
+
return joint_skin_embeds, vert_skin_embeds
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def clamp_with_grad(x, min, max):
|
| 80 |
+
return x + (x.clamp(min, max) - x).detach()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class TreeTransformerSkinDecoder(nn.Module):
|
| 84 |
+
# The principles of the tree transformer skinning model:
|
| 85 |
+
# 1. joint features are related to the tree structure, since the decoding process is skeleton-agnostic.
|
| 86 |
+
# 2. decode the skinning weights directly, hoping the transformer can handle the skinning assignment.
|
| 87 |
+
# It's a pure learning-based method.
|
| 88 |
+
def __init__(self,
|
| 89 |
+
skin_feat_channels: int,
|
| 90 |
+
model_channels: int=512,
|
| 91 |
+
num_heads=4,
|
| 92 |
+
num_blocks=4,
|
| 93 |
+
vert_cross_blocks_num: int = 1,
|
| 94 |
+
use_fp16: bool = False):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.skin_feat_channels = skin_feat_channels
|
| 97 |
+
self.model_channels = model_channels
|
| 98 |
+
self.num_heads = num_heads
|
| 99 |
+
self.num_blocks = num_blocks
|
| 100 |
+
self.root_features = nn.Parameter(torch.zeros([1, skin_feat_channels]), requires_grad=True)
|
| 101 |
+
self.input_layer_vertex = nn.Linear(skin_feat_channels, model_channels)
|
| 102 |
+
self.input_layer_skin = nn.Linear(skin_feat_channels*2, model_channels)
|
| 103 |
+
assert vert_cross_blocks_num <= num_blocks, f"vert_cross_blocks_num should be less than or equal to num_blocks, got {vert_cross_blocks_num} and {num_blocks}."
|
| 104 |
+
self.vert_cross_blocks_num = vert_cross_blocks_num
|
| 105 |
+
self.blocks_vertex = nn.ModuleList([
|
| 106 |
+
TransformerCrossBlock(
|
| 107 |
+
channels=model_channels,
|
| 108 |
+
ctx_channels=model_channels,
|
| 109 |
+
num_heads=num_heads,
|
| 110 |
+
mlp_ratio=4.0,
|
| 111 |
+
attn_mode="full",
|
| 112 |
+
no_self=True)
|
| 113 |
+
for _ in range(self.vert_cross_blocks_num)
|
| 114 |
+
] + [
|
| 115 |
+
FeedForwardNet(
|
| 116 |
+
channels=model_channels,
|
| 117 |
+
mlp_ratio=4.0,
|
| 118 |
+
out_channels=model_channels,
|
| 119 |
+
)
|
| 120 |
+
for _ in range(num_blocks - self.vert_cross_blocks_num)
|
| 121 |
+
])
|
| 122 |
+
self.blocks_skin = nn.ModuleList([
|
| 123 |
+
TransformerBlock(
|
| 124 |
+
channels=model_channels,
|
| 125 |
+
num_heads=num_heads,
|
| 126 |
+
mlp_ratio=4.0,
|
| 127 |
+
attn_mode="full")
|
| 128 |
+
for _ in range(num_blocks)
|
| 129 |
+
])
|
| 130 |
+
self.out_layer_vertex = nn.Sequential(
|
| 131 |
+
nn.Linear(model_channels, model_channels*4),
|
| 132 |
+
nn.GELU(approximate="tanh"),
|
| 133 |
+
nn.Linear(model_channels*4, model_channels+1),
|
| 134 |
+
)
|
| 135 |
+
self.out_layer_skin = nn.Sequential(
|
| 136 |
+
nn.Linear(model_channels, model_channels*4),
|
| 137 |
+
nn.GELU(approximate="tanh"),
|
| 138 |
+
nn.Linear(model_channels*4, model_channels),
|
| 139 |
+
)
|
| 140 |
+
self.temp_activation = nn.ELU(alpha=1.0)
|
| 141 |
+
self.dtype = torch.float16 if use_fp16 else torch.float32
|
| 142 |
+
|
| 143 |
+
@property
|
| 144 |
+
def device(self) -> torch.device:
|
| 145 |
+
"""
|
| 146 |
+
Return the device of the model.
|
| 147 |
+
"""
|
| 148 |
+
return next(self.parameters()).device
|
| 149 |
+
|
| 150 |
+
def convert_to_fp16(self) -> None:
|
| 151 |
+
"""
|
| 152 |
+
Convert the torso of the model to float16.
|
| 153 |
+
"""
|
| 154 |
+
self.blocks_vertex.apply(convert_module_to_f16)
|
| 155 |
+
self.blocks_skin.apply(convert_module_to_f16)
|
| 156 |
+
|
| 157 |
+
def convert_to_fp32(self) -> None:
|
| 158 |
+
"""
|
| 159 |
+
Convert the torso of the model to float32.
|
| 160 |
+
"""
|
| 161 |
+
self.blocks_vertex.apply(convert_module_to_f32)
|
| 162 |
+
self.blocks_skin.apply(convert_module_to_f32)
|
| 163 |
+
|
| 164 |
+
def initialize_weights(self) -> None:
|
| 165 |
+
# Initialize transformer layers:
|
| 166 |
+
def _basic_init(module):
|
| 167 |
+
if isinstance(module, nn.Linear):
|
| 168 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 169 |
+
if module.bias is not None:
|
| 170 |
+
nn.init.constant_(module.bias, 0)
|
| 171 |
+
self.apply(_basic_init)
|
| 172 |
+
|
| 173 |
+
def forward(self, vertex_features, joint_features, parents) -> torch.Tensor:
|
| 174 |
+
j_num = joint_features.shape[1]
|
| 175 |
+
h_v = vertex_features
|
| 176 |
+
h_v = self.input_layer_vertex(h_v)
|
| 177 |
+
h_j = joint_features
|
| 178 |
+
h_j = torch.cat([h_j, self.root_features[None]], dim=1)
|
| 179 |
+
parents = torch.where(parents < 0, torch.ones_like(parents)*j_num, parents)
|
| 180 |
+
h_j = torch.cat([h_j[:, :-1], h_j[:, parents[0]]], dim=-1)
|
| 181 |
+
h_j = self.input_layer_skin(h_j)
|
| 182 |
+
h_v = h_v.type(self.dtype)
|
| 183 |
+
h_j = h_j.type(self.dtype)
|
| 184 |
+
blocks_num = len(self.blocks_vertex)
|
| 185 |
+
for idx, block_v, block_j in zip(range(blocks_num), self.blocks_vertex, self.blocks_skin):
|
| 186 |
+
f_v, f_j = h_v, h_j
|
| 187 |
+
h_v = block_v(f_v, f_j) if idx < self.vert_cross_blocks_num else block_v(f_v)
|
| 188 |
+
h_j = block_j(f_j)
|
| 189 |
+
h_v = h_v.type(vertex_features.dtype)
|
| 190 |
+
h_j = h_j.type(joint_features.dtype)
|
| 191 |
+
h_v = self.out_layer_vertex(h_v)
|
| 192 |
+
h_j = self.out_layer_skin(h_j)
|
| 193 |
+
h_v, inv_temp = h_v[..., :-1], h_v[..., -1].unsqueeze(-1)
|
| 194 |
+
inv_temp = self.temp_activation(inv_temp) + self.temp_activation.alpha + 1.0
|
| 195 |
+
skin_weights = torch.einsum("nac,nbc->nab", h_v, h_j)
|
| 196 |
+
skin_weights = torch.softmax(skin_weights * inv_temp, dim=-1)
|
| 197 |
+
return skin_weights
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
SKIN_MODEL_DICT = {'tree': TreeTransformerSkinDecoder}
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class SkinAutoEncoder(nn.Module):
|
| 204 |
+
def __init__(self, encoder_config: Dict[str, Any], decoder_config: Dict[str, Any], use_fp16: bool = False):
|
| 205 |
+
super().__init__()
|
| 206 |
+
self.skin_encoder = SkinEncoder(**encoder_config)
|
| 207 |
+
decoder_config['use_fp16'] = use_fp16
|
| 208 |
+
self.skin_decoder = SKIN_MODEL_DICT[decoder_config.pop('model_type')](**decoder_config)
|
| 209 |
+
|
| 210 |
+
self.initialize_weights()
|
| 211 |
+
if use_fp16:
|
| 212 |
+
self.convert_to_fp16()
|
| 213 |
+
else:
|
| 214 |
+
self.convert_to_fp32()
|
| 215 |
+
|
| 216 |
+
def convert_to_fp16(self) -> None:
|
| 217 |
+
self.skin_decoder.convert_to_fp16()
|
| 218 |
+
|
| 219 |
+
def convert_to_fp32(self) -> None:
|
| 220 |
+
self.skin_decoder.convert_to_fp32()
|
| 221 |
+
|
| 222 |
+
def initialize_weights(self) -> None:
|
| 223 |
+
# Initialize transformer layers:
|
| 224 |
+
def _basic_init(module):
|
| 225 |
+
if isinstance(module, nn.Linear):
|
| 226 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 227 |
+
if module.bias is not None:
|
| 228 |
+
nn.init.constant_(module.bias, 0)
|
| 229 |
+
self.apply(_basic_init)
|
| 230 |
+
|
| 231 |
+
def encode(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]):
|
| 232 |
+
joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
|
| 233 |
+
return joint_skin_embeds, vert_skin_embeds
|
| 234 |
+
|
| 235 |
+
def decode(self, vertex_features, joint_features, parents) -> torch.Tensor:
|
| 236 |
+
skin_weights = self.skin_decoder(vertex_features, joint_features, parents)
|
| 237 |
+
return skin_weights
|
| 238 |
+
|
| 239 |
+
def forward(self, joints_list: List[torch.Tensor], parents_list: List[torch.Tensor], skin_list: List[torch.Tensor]):
|
| 240 |
+
joint_skin_embeds, vert_skin_embeds = self.skin_encoder(joints_list, parents_list, skin_list)
|
| 241 |
+
skin_pred_list = []
|
| 242 |
+
for i in range(len(joints_list)):
|
| 243 |
+
skin_pred = self.skin_decoder(vert_skin_embeds[i][None], joint_skin_embeds[i][None], parents_list[i][None])
|
| 244 |
+
skin_pred_list.append(skin_pred[0])
|
| 245 |
+
return skin_pred_list, joint_skin_embeds, vert_skin_embeds
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class AniGenElasticSLatEncoderGamma(SparseTransformerElasticMixin, SkinAutoEncoder):
|
| 249 |
+
"""
|
| 250 |
+
SLat VAE encoder with elastic memory management.
|
| 251 |
+
Used for training with low VRAM.
|
| 252 |
+
"""
|
anigen/modules/attention/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
BACKEND = 'flash_attn'
|
| 4 |
+
DEBUG = False
|
| 5 |
+
|
| 6 |
+
def __from_env():
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
global BACKEND
|
| 10 |
+
global DEBUG
|
| 11 |
+
|
| 12 |
+
env_attn_backend = os.environ.get('ATTN_BACKEND')
|
| 13 |
+
env_sttn_debug = os.environ.get('ATTN_DEBUG')
|
| 14 |
+
|
| 15 |
+
if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
|
| 16 |
+
BACKEND = env_attn_backend
|
| 17 |
+
if env_sttn_debug is not None:
|
| 18 |
+
DEBUG = env_sttn_debug == '1'
|
| 19 |
+
|
| 20 |
+
print(f"[ATTENTION] Using backend: {BACKEND}")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__from_env()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def set_backend(backend: Literal['xformers', 'flash_attn']):
|
| 27 |
+
global BACKEND
|
| 28 |
+
BACKEND = backend
|
| 29 |
+
|
| 30 |
+
def set_debug(debug: bool):
|
| 31 |
+
global DEBUG
|
| 32 |
+
DEBUG = debug
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
from .full_attn import *
|
| 36 |
+
from .modules import *
|
anigen/modules/attention/full_attn.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
from . import DEBUG, BACKEND
|
| 5 |
+
|
| 6 |
+
if BACKEND == 'xformers':
|
| 7 |
+
import xformers.ops as xops
|
| 8 |
+
elif BACKEND == 'flash_attn':
|
| 9 |
+
import flash_attn
|
| 10 |
+
elif BACKEND == 'sdpa':
|
| 11 |
+
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 12 |
+
elif BACKEND == 'naive':
|
| 13 |
+
pass
|
| 14 |
+
else:
|
| 15 |
+
raise ValueError(f"Unknown attention backend: {BACKEND}")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
'scaled_dot_product_attention',
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _naive_sdpa(q, k, v):
|
| 24 |
+
"""
|
| 25 |
+
Naive implementation of scaled dot product attention.
|
| 26 |
+
"""
|
| 27 |
+
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 28 |
+
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 29 |
+
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 30 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
| 31 |
+
attn_weight = q @ k.transpose(-2, -1) * scale_factor
|
| 32 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
| 33 |
+
out = attn_weight @ v
|
| 34 |
+
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
| 35 |
+
return out
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@overload
|
| 39 |
+
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
|
| 40 |
+
"""
|
| 41 |
+
Apply scaled dot product attention.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
|
| 45 |
+
"""
|
| 46 |
+
...
|
| 47 |
+
|
| 48 |
+
@overload
|
| 49 |
+
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
|
| 50 |
+
"""
|
| 51 |
+
Apply scaled dot product attention.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
|
| 55 |
+
kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
|
| 56 |
+
"""
|
| 57 |
+
...
|
| 58 |
+
|
| 59 |
+
@overload
|
| 60 |
+
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
| 61 |
+
"""
|
| 62 |
+
Apply scaled dot product attention.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
|
| 66 |
+
k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
|
| 67 |
+
v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
|
| 68 |
+
|
| 69 |
+
Note:
|
| 70 |
+
k and v are assumed to have the same coordinate map.
|
| 71 |
+
"""
|
| 72 |
+
...
|
| 73 |
+
|
| 74 |
+
def scaled_dot_product_attention(*args, **kwargs):
|
| 75 |
+
arg_names_dict = {
|
| 76 |
+
1: ['qkv'],
|
| 77 |
+
2: ['q', 'kv'],
|
| 78 |
+
3: ['q', 'k', 'v']
|
| 79 |
+
}
|
| 80 |
+
num_all_args = len(args) + len(kwargs)
|
| 81 |
+
assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
|
| 82 |
+
for key in arg_names_dict[num_all_args][len(args):]:
|
| 83 |
+
assert key in kwargs, f"Missing argument {key}"
|
| 84 |
+
|
| 85 |
+
if num_all_args == 1:
|
| 86 |
+
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
| 87 |
+
assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
|
| 88 |
+
device = qkv.device
|
| 89 |
+
|
| 90 |
+
elif num_all_args == 2:
|
| 91 |
+
q = args[0] if len(args) > 0 else kwargs['q']
|
| 92 |
+
kv = args[1] if len(args) > 1 else kwargs['kv']
|
| 93 |
+
assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
|
| 94 |
+
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
|
| 95 |
+
assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
|
| 96 |
+
device = q.device
|
| 97 |
+
|
| 98 |
+
elif num_all_args == 3:
|
| 99 |
+
q = args[0] if len(args) > 0 else kwargs['q']
|
| 100 |
+
k = args[1] if len(args) > 1 else kwargs['k']
|
| 101 |
+
v = args[2] if len(args) > 2 else kwargs['v']
|
| 102 |
+
assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
|
| 103 |
+
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
|
| 104 |
+
assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
|
| 105 |
+
assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
|
| 106 |
+
device = q.device
|
| 107 |
+
|
| 108 |
+
if BACKEND == 'xformers':
|
| 109 |
+
if num_all_args == 1:
|
| 110 |
+
q, k, v = qkv.unbind(dim=2)
|
| 111 |
+
elif num_all_args == 2:
|
| 112 |
+
k, v = kv.unbind(dim=2)
|
| 113 |
+
out = xops.memory_efficient_attention(q, k, v)
|
| 114 |
+
elif BACKEND == 'flash_attn':
|
| 115 |
+
if num_all_args == 1:
|
| 116 |
+
out = flash_attn.flash_attn_qkvpacked_func(qkv)
|
| 117 |
+
elif num_all_args == 2:
|
| 118 |
+
out = flash_attn.flash_attn_kvpacked_func(q, kv)
|
| 119 |
+
elif num_all_args == 3:
|
| 120 |
+
out = flash_attn.flash_attn_func(q, k, v)
|
| 121 |
+
elif BACKEND == 'sdpa':
|
| 122 |
+
if num_all_args == 1:
|
| 123 |
+
q, k, v = qkv.unbind(dim=2)
|
| 124 |
+
elif num_all_args == 2:
|
| 125 |
+
k, v = kv.unbind(dim=2)
|
| 126 |
+
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 127 |
+
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 128 |
+
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
| 129 |
+
out = sdpa(q, k, v) # [N, H, L, C]
|
| 130 |
+
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
| 131 |
+
elif BACKEND == 'naive':
|
| 132 |
+
if num_all_args == 1:
|
| 133 |
+
q, k, v = qkv.unbind(dim=2)
|
| 134 |
+
elif num_all_args == 2:
|
| 135 |
+
k, v = kv.unbind(dim=2)
|
| 136 |
+
out = _naive_sdpa(q, k, v)
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError(f"Unknown attention module: {BACKEND}")
|
| 139 |
+
|
| 140 |
+
return out
|
anigen/modules/attention/modules.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from .full_attn import scaled_dot_product_attention
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MultiHeadRMSNorm(nn.Module):
|
| 9 |
+
def __init__(self, dim: int, heads: int):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.scale = dim ** 0.5
|
| 12 |
+
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
| 13 |
+
|
| 14 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RotaryPositionEmbedder(nn.Module):
|
| 19 |
+
def __init__(self, hidden_size: int, in_channels: int = 3):
|
| 20 |
+
super().__init__()
|
| 21 |
+
assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
|
| 22 |
+
self.hidden_size = hidden_size
|
| 23 |
+
self.in_channels = in_channels
|
| 24 |
+
self.freq_dim = hidden_size // in_channels // 2
|
| 25 |
+
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
| 26 |
+
self.freqs = 1.0 / (10000 ** self.freqs)
|
| 27 |
+
|
| 28 |
+
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
| 29 |
+
self.freqs = self.freqs.to(indices.device)
|
| 30 |
+
phases = torch.outer(indices, self.freqs)
|
| 31 |
+
phases = torch.polar(torch.ones_like(phases), phases)
|
| 32 |
+
return phases
|
| 33 |
+
|
| 34 |
+
def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
| 35 |
+
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 36 |
+
x_rotated = x_complex * phases
|
| 37 |
+
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
| 38 |
+
return x_embed
|
| 39 |
+
|
| 40 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 41 |
+
"""
|
| 42 |
+
Args:
|
| 43 |
+
q (sp.SparseTensor): [..., N, D] tensor of queries
|
| 44 |
+
k (sp.SparseTensor): [..., N, D] tensor of keys
|
| 45 |
+
indices (torch.Tensor): [..., N, C] tensor of spatial positions
|
| 46 |
+
"""
|
| 47 |
+
if indices is None:
|
| 48 |
+
indices = torch.arange(q.shape[-2], device=q.device)
|
| 49 |
+
if len(q.shape) > 2:
|
| 50 |
+
indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
|
| 51 |
+
|
| 52 |
+
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
| 53 |
+
if phases.shape[1] < self.hidden_size // 2:
|
| 54 |
+
phases = torch.cat([phases, torch.polar(
|
| 55 |
+
torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
|
| 56 |
+
torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
|
| 57 |
+
)], dim=-1)
|
| 58 |
+
q_embed = self._rotary_embedding(q, phases)
|
| 59 |
+
k_embed = self._rotary_embedding(k, phases)
|
| 60 |
+
return q_embed, k_embed
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class LoRALinear(nn.Linear):
|
| 64 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = True, rank: int = 4, lr_rate: float = 1.0):
|
| 65 |
+
super().__init__(in_features, out_features, bias)
|
| 66 |
+
self.rank = rank
|
| 67 |
+
self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
|
| 68 |
+
self.lora_B = nn.Parameter(torch.randn(rank, out_features) * 1e-2)
|
| 69 |
+
self.lr_rate = lr_rate
|
| 70 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 71 |
+
return super().forward(x) + (x @ self.lora_A) @ self.lora_B * self.lr_rate
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class MultiHeadAttention(nn.Module):
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
channels: int,
|
| 78 |
+
num_heads: int,
|
| 79 |
+
ctx_channels: Optional[int]=None,
|
| 80 |
+
type: Literal["self", "cross"] = "self",
|
| 81 |
+
attn_mode: Literal["full", "windowed"] = "full",
|
| 82 |
+
window_size: Optional[int] = None,
|
| 83 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 84 |
+
qkv_bias: bool = True,
|
| 85 |
+
use_rope: bool = False,
|
| 86 |
+
qk_rms_norm: bool = False,
|
| 87 |
+
x_is_query: bool = False,
|
| 88 |
+
use_lora: bool = False,
|
| 89 |
+
lora_rank: int = 4,
|
| 90 |
+
lora_lr_rate: float = 1.0,
|
| 91 |
+
):
|
| 92 |
+
super().__init__()
|
| 93 |
+
assert channels % num_heads == 0
|
| 94 |
+
assert type in ["self", "cross"], f"Invalid attention type: {type}"
|
| 95 |
+
assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
|
| 96 |
+
assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
|
| 97 |
+
|
| 98 |
+
if attn_mode == "windowed":
|
| 99 |
+
raise NotImplementedError("Windowed attention is not yet implemented")
|
| 100 |
+
|
| 101 |
+
self.channels = channels
|
| 102 |
+
self.head_dim = channels // num_heads
|
| 103 |
+
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
| 104 |
+
self.num_heads = num_heads
|
| 105 |
+
self._type = type
|
| 106 |
+
self.attn_mode = attn_mode
|
| 107 |
+
self.window_size = window_size
|
| 108 |
+
self.shift_window = shift_window
|
| 109 |
+
self.use_rope = use_rope
|
| 110 |
+
self.qk_rms_norm = qk_rms_norm
|
| 111 |
+
|
| 112 |
+
if self._type == "self":
|
| 113 |
+
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) if not use_lora else LoRALinear(channels, channels * 3, bias=qkv_bias, rank=lora_rank, lr_rate=lora_lr_rate)
|
| 114 |
+
else:
|
| 115 |
+
self.to_q = (lambda x: x) if x_is_query else (nn.Linear(channels, channels, bias=qkv_bias) if not use_lora else LoRALinear(channels, channels, bias=qkv_bias, rank=lora_rank, lr_rate=lora_lr_rate))
|
| 116 |
+
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) if not use_lora else LoRALinear(self.ctx_channels, channels * 2, bias=qkv_bias, rank=lora_rank, lr_rate=lora_lr_rate)
|
| 117 |
+
|
| 118 |
+
if self.qk_rms_norm:
|
| 119 |
+
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
| 120 |
+
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
| 121 |
+
|
| 122 |
+
self.to_out = nn.Linear(channels, channels) if not use_lora else LoRALinear(channels, channels, rank=lora_rank, lr_rate=lora_lr_rate)
|
| 123 |
+
|
| 124 |
+
if use_rope:
|
| 125 |
+
self.rope = RotaryPositionEmbedder(channels)
|
| 126 |
+
|
| 127 |
+
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 128 |
+
B, L, C = x.shape
|
| 129 |
+
if self._type == "self":
|
| 130 |
+
qkv = self.to_qkv(x)
|
| 131 |
+
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
| 132 |
+
if self.use_rope:
|
| 133 |
+
q, k, v = qkv.unbind(dim=2)
|
| 134 |
+
q, k = self.rope(q, k, indices)
|
| 135 |
+
qkv = torch.stack([q, k, v], dim=2)
|
| 136 |
+
if self.attn_mode == "full":
|
| 137 |
+
if self.qk_rms_norm:
|
| 138 |
+
q, k, v = qkv.unbind(dim=2)
|
| 139 |
+
q = self.q_rms_norm(q)
|
| 140 |
+
k = self.k_rms_norm(k)
|
| 141 |
+
h = scaled_dot_product_attention(q, k, v)
|
| 142 |
+
else:
|
| 143 |
+
h = scaled_dot_product_attention(qkv)
|
| 144 |
+
elif self.attn_mode == "windowed":
|
| 145 |
+
raise NotImplementedError("Windowed attention is not yet implemented")
|
| 146 |
+
else:
|
| 147 |
+
Lkv = context.shape[1]
|
| 148 |
+
q = self.to_q(x)
|
| 149 |
+
kv = self.to_kv(context)
|
| 150 |
+
q = q.reshape(B, L, self.num_heads, -1)
|
| 151 |
+
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
| 152 |
+
if self.qk_rms_norm:
|
| 153 |
+
q = self.q_rms_norm(q)
|
| 154 |
+
k, v = kv.unbind(dim=2)
|
| 155 |
+
k = self.k_rms_norm(k)
|
| 156 |
+
h = scaled_dot_product_attention(q, k, v)
|
| 157 |
+
else:
|
| 158 |
+
h = scaled_dot_product_attention(q, kv)
|
| 159 |
+
h = h.reshape(B, L, -1)
|
| 160 |
+
h = self.to_out(h)
|
| 161 |
+
return h
|
anigen/modules/norm.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class LayerNorm32(nn.LayerNorm):
|
| 6 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 7 |
+
return super().forward(x.float()).type(x.dtype)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GroupNorm32(nn.GroupNorm):
|
| 11 |
+
"""
|
| 12 |
+
A GroupNorm layer that converts to float32 before the forward pass.
|
| 13 |
+
"""
|
| 14 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
return super().forward(x.float()).type(x.dtype)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ChannelLayerNorm32(LayerNorm32):
|
| 19 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 20 |
+
DIM = x.dim()
|
| 21 |
+
x = x.permute(0, *range(2, DIM), 1).contiguous()
|
| 22 |
+
x = super().forward(x)
|
| 23 |
+
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
|
| 24 |
+
return x
|
| 25 |
+
|
anigen/modules/sparse/__init__.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
BACKEND = 'spconv'
|
| 4 |
+
DEBUG = False
|
| 5 |
+
ATTN = 'flash_attn'
|
| 6 |
+
|
| 7 |
+
def __from_env():
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
global BACKEND
|
| 11 |
+
global DEBUG
|
| 12 |
+
global ATTN
|
| 13 |
+
|
| 14 |
+
env_sparse_backend = os.environ.get('SPARSE_BACKEND')
|
| 15 |
+
env_sparse_debug = os.environ.get('SPARSE_DEBUG')
|
| 16 |
+
env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
|
| 17 |
+
if env_sparse_attn is None:
|
| 18 |
+
env_sparse_attn = os.environ.get('ATTN_BACKEND')
|
| 19 |
+
|
| 20 |
+
if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
|
| 21 |
+
BACKEND = env_sparse_backend
|
| 22 |
+
if env_sparse_debug is not None:
|
| 23 |
+
DEBUG = env_sparse_debug == '1'
|
| 24 |
+
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
|
| 25 |
+
ATTN = env_sparse_attn
|
| 26 |
+
|
| 27 |
+
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__from_env()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def set_backend(backend: Literal['spconv', 'torchsparse']):
|
| 34 |
+
global BACKEND
|
| 35 |
+
BACKEND = backend
|
| 36 |
+
|
| 37 |
+
def set_debug(debug: bool):
|
| 38 |
+
global DEBUG
|
| 39 |
+
DEBUG = debug
|
| 40 |
+
|
| 41 |
+
def set_attn(attn: Literal['xformers', 'flash_attn']):
|
| 42 |
+
global ATTN
|
| 43 |
+
ATTN = attn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
import importlib
|
| 47 |
+
|
| 48 |
+
__attributes = {
|
| 49 |
+
'SparseTensor': 'basic',
|
| 50 |
+
'sparse_batch_broadcast': 'basic',
|
| 51 |
+
'sparse_batch_op': 'basic',
|
| 52 |
+
'sparse_cat': 'basic',
|
| 53 |
+
'sparse_unbind': 'basic',
|
| 54 |
+
'SparseGroupNorm': 'norm',
|
| 55 |
+
'SparseLayerNorm': 'norm',
|
| 56 |
+
'SparseGroupNorm32': 'norm',
|
| 57 |
+
'SparseLayerNorm32': 'norm',
|
| 58 |
+
'SparseReLU': 'nonlinearity',
|
| 59 |
+
'SparseSiLU': 'nonlinearity',
|
| 60 |
+
'SparseGELU': 'nonlinearity',
|
| 61 |
+
'SparseActivation': 'nonlinearity',
|
| 62 |
+
'SparseLinear': 'linear',
|
| 63 |
+
'sparse_scaled_dot_product_attention': 'attention',
|
| 64 |
+
'SerializeMode': 'attention',
|
| 65 |
+
'sparse_serialized_scaled_dot_product_self_attention': 'attention',
|
| 66 |
+
'sparse_windowed_scaled_dot_product_self_attention': 'attention',
|
| 67 |
+
'SparseMultiHeadAttention': 'attention',
|
| 68 |
+
'SparseConv3d': 'conv',
|
| 69 |
+
'SparseInverseConv3d': 'conv',
|
| 70 |
+
'SparseDownsample': 'spatial',
|
| 71 |
+
'SparseUpsample': 'spatial',
|
| 72 |
+
'SparseSubdivide' : 'spatial'
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
__submodules = ['transformer']
|
| 76 |
+
|
| 77 |
+
__all__ = list(__attributes.keys()) + __submodules
|
| 78 |
+
|
| 79 |
+
def __getattr__(name):
|
| 80 |
+
if name not in globals():
|
| 81 |
+
if name in __attributes:
|
| 82 |
+
module_name = __attributes[name]
|
| 83 |
+
module = importlib.import_module(f".{module_name}", __name__)
|
| 84 |
+
globals()[name] = getattr(module, name)
|
| 85 |
+
elif name in __submodules:
|
| 86 |
+
module = importlib.import_module(f".{name}", __name__)
|
| 87 |
+
globals()[name] = module
|
| 88 |
+
else:
|
| 89 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
| 90 |
+
return globals()[name]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# For Pylance
|
| 94 |
+
if __name__ == '__main__':
|
| 95 |
+
from .basic import *
|
| 96 |
+
from .norm import *
|
| 97 |
+
from .nonlinearity import *
|
| 98 |
+
from .linear import *
|
| 99 |
+
from .attention import *
|
| 100 |
+
from .conv import *
|
| 101 |
+
from .spatial import *
|
| 102 |
+
import transformer
|
anigen/modules/sparse/attention/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .full_attn import *
|
| 2 |
+
from .serialized_attn import *
|
| 3 |
+
from .windowed_attn import *
|
| 4 |
+
from .modules import *
|
| 5 |
+
from .windowed_attn_cross import *
|
anigen/modules/sparse/attention/full_attn.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
from .. import SparseTensor
|
| 4 |
+
from .. import DEBUG, ATTN
|
| 5 |
+
|
| 6 |
+
if ATTN == 'xformers':
|
| 7 |
+
import xformers.ops as xops
|
| 8 |
+
elif ATTN == 'flash_attn':
|
| 9 |
+
import flash_attn
|
| 10 |
+
else:
|
| 11 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
'sparse_scaled_dot_product_attention',
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@overload
|
| 20 |
+
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
|
| 21 |
+
"""
|
| 22 |
+
Apply scaled dot product attention to a sparse tensor.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
|
| 26 |
+
"""
|
| 27 |
+
...
|
| 28 |
+
|
| 29 |
+
@overload
|
| 30 |
+
def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
|
| 31 |
+
"""
|
| 32 |
+
Apply scaled dot product attention to a sparse tensor.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
|
| 36 |
+
kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
|
| 37 |
+
"""
|
| 38 |
+
...
|
| 39 |
+
|
| 40 |
+
@overload
|
| 41 |
+
def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
Apply scaled dot product attention to a sparse tensor.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
q (SparseTensor): A [N, L, H, C] dense tensor containing Qs.
|
| 47 |
+
kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
|
| 48 |
+
"""
|
| 49 |
+
...
|
| 50 |
+
|
| 51 |
+
@overload
|
| 52 |
+
def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
|
| 53 |
+
"""
|
| 54 |
+
Apply scaled dot product attention to a sparse tensor.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
|
| 58 |
+
k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
|
| 59 |
+
v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
|
| 60 |
+
|
| 61 |
+
Note:
|
| 62 |
+
k and v are assumed to have the same coordinate map.
|
| 63 |
+
"""
|
| 64 |
+
...
|
| 65 |
+
|
| 66 |
+
@overload
|
| 67 |
+
def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
|
| 68 |
+
"""
|
| 69 |
+
Apply scaled dot product attention to a sparse tensor.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
|
| 73 |
+
k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
|
| 74 |
+
v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
|
| 75 |
+
"""
|
| 76 |
+
...
|
| 77 |
+
|
| 78 |
+
@overload
|
| 79 |
+
def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
|
| 80 |
+
"""
|
| 81 |
+
Apply scaled dot product attention to a sparse tensor.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
|
| 85 |
+
k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
|
| 86 |
+
v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
|
| 87 |
+
"""
|
| 88 |
+
...
|
| 89 |
+
|
| 90 |
+
def sparse_scaled_dot_product_attention(*args, **kwargs):
|
| 91 |
+
arg_names_dict = {
|
| 92 |
+
1: ['qkv'],
|
| 93 |
+
2: ['q', 'kv'],
|
| 94 |
+
3: ['q', 'k', 'v']
|
| 95 |
+
}
|
| 96 |
+
num_all_args = len(args) + len(kwargs)
|
| 97 |
+
assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
|
| 98 |
+
for key in arg_names_dict[num_all_args][len(args):]:
|
| 99 |
+
assert key in kwargs, f"Missing argument {key}"
|
| 100 |
+
|
| 101 |
+
if num_all_args == 1:
|
| 102 |
+
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
| 103 |
+
assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
|
| 104 |
+
assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
|
| 105 |
+
device = qkv.device
|
| 106 |
+
|
| 107 |
+
s = qkv
|
| 108 |
+
q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
|
| 109 |
+
kv_seqlen = q_seqlen
|
| 110 |
+
qkv = qkv.feats # [T, 3, H, C]
|
| 111 |
+
|
| 112 |
+
elif num_all_args == 2:
|
| 113 |
+
q = args[0] if len(args) > 0 else kwargs['q']
|
| 114 |
+
kv = args[1] if len(args) > 1 else kwargs['kv']
|
| 115 |
+
assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
|
| 116 |
+
isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
|
| 117 |
+
f"Invalid types, got {type(q)} and {type(kv)}"
|
| 118 |
+
assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
|
| 119 |
+
device = q.device
|
| 120 |
+
|
| 121 |
+
if isinstance(q, SparseTensor):
|
| 122 |
+
assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
|
| 123 |
+
s = q
|
| 124 |
+
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
|
| 125 |
+
q = q.feats # [T_Q, H, C]
|
| 126 |
+
else:
|
| 127 |
+
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
|
| 128 |
+
s = None
|
| 129 |
+
N, L, H, C = q.shape
|
| 130 |
+
q_seqlen = [L] * N
|
| 131 |
+
q = q.reshape(N * L, H, C) # [T_Q, H, C]
|
| 132 |
+
|
| 133 |
+
if isinstance(kv, SparseTensor):
|
| 134 |
+
assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
|
| 135 |
+
kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
|
| 136 |
+
kv = kv.feats # [T_KV, 2, H, C]
|
| 137 |
+
else:
|
| 138 |
+
assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
|
| 139 |
+
N, L, _, H, C = kv.shape
|
| 140 |
+
kv_seqlen = [L] * N
|
| 141 |
+
kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
|
| 142 |
+
|
| 143 |
+
elif num_all_args == 3:
|
| 144 |
+
q = args[0] if len(args) > 0 else kwargs['q']
|
| 145 |
+
k = args[1] if len(args) > 1 else kwargs['k']
|
| 146 |
+
v = args[2] if len(args) > 2 else kwargs['v']
|
| 147 |
+
assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
|
| 148 |
+
isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
|
| 149 |
+
f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
|
| 150 |
+
assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
|
| 151 |
+
device = q.device
|
| 152 |
+
|
| 153 |
+
if isinstance(q, SparseTensor):
|
| 154 |
+
assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
|
| 155 |
+
s = q
|
| 156 |
+
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
|
| 157 |
+
q = q.feats # [T_Q, H, Ci]
|
| 158 |
+
else:
|
| 159 |
+
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
|
| 160 |
+
s = None
|
| 161 |
+
N, L, H, CI = q.shape
|
| 162 |
+
q_seqlen = [L] * N
|
| 163 |
+
q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
|
| 164 |
+
|
| 165 |
+
if isinstance(k, SparseTensor):
|
| 166 |
+
assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
|
| 167 |
+
assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
|
| 168 |
+
kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
|
| 169 |
+
k = k.feats # [T_KV, H, Ci]
|
| 170 |
+
v = v.feats # [T_KV, H, Co]
|
| 171 |
+
else:
|
| 172 |
+
assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
|
| 173 |
+
assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
|
| 174 |
+
N, L, H, CI, CO = *k.shape, v.shape[-1]
|
| 175 |
+
kv_seqlen = [L] * N
|
| 176 |
+
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
|
| 177 |
+
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
|
| 178 |
+
|
| 179 |
+
if DEBUG:
|
| 180 |
+
if s is not None:
|
| 181 |
+
for i in range(s.shape[0]):
|
| 182 |
+
assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
|
| 183 |
+
if num_all_args in [2, 3]:
|
| 184 |
+
assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
|
| 185 |
+
if num_all_args == 3:
|
| 186 |
+
assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
|
| 187 |
+
assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"
|
| 188 |
+
|
| 189 |
+
if ATTN == 'xformers':
|
| 190 |
+
if num_all_args == 1:
|
| 191 |
+
q, k, v = qkv.unbind(dim=1)
|
| 192 |
+
elif num_all_args == 2:
|
| 193 |
+
k, v = kv.unbind(dim=1)
|
| 194 |
+
q = q.unsqueeze(0)
|
| 195 |
+
k = k.unsqueeze(0)
|
| 196 |
+
v = v.unsqueeze(0)
|
| 197 |
+
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
|
| 198 |
+
out = xops.memory_efficient_attention(q, k, v, mask)[0]
|
| 199 |
+
elif ATTN == 'flash_attn':
|
| 200 |
+
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
| 201 |
+
if num_all_args in [2, 3]:
|
| 202 |
+
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
| 203 |
+
if num_all_args == 1:
|
| 204 |
+
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
|
| 205 |
+
elif num_all_args == 2:
|
| 206 |
+
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
| 207 |
+
elif num_all_args == 3:
|
| 208 |
+
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
| 209 |
+
else:
|
| 210 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 211 |
+
|
| 212 |
+
if s is not None:
|
| 213 |
+
return s.replace(out)
|
| 214 |
+
else:
|
| 215 |
+
return out.reshape(N, L, H, -1)
|
anigen/modules/sparse/attention/modules.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from .. import SparseTensor
|
| 6 |
+
from .full_attn import sparse_scaled_dot_product_attention
|
| 7 |
+
from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
|
| 8 |
+
from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
|
| 9 |
+
from .windowed_attn_cross import sparse_windowed_scaled_dot_product_cross_attention
|
| 10 |
+
from ...attention import RotaryPositionEmbedder
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SparseMultiHeadRMSNorm(nn.Module):
|
| 14 |
+
def __init__(self, dim: int, heads: int):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.scale = dim ** 0.5
|
| 17 |
+
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
| 18 |
+
|
| 19 |
+
def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
|
| 20 |
+
x_type = x.dtype
|
| 21 |
+
x = x.float()
|
| 22 |
+
if isinstance(x, SparseTensor):
|
| 23 |
+
x = x.replace(F.normalize(x.feats, dim=-1))
|
| 24 |
+
else:
|
| 25 |
+
x = F.normalize(x, dim=-1)
|
| 26 |
+
return (x * self.gamma * self.scale).to(x_type)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SparseMultiHeadAttention(nn.Module):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
channels: int,
|
| 33 |
+
num_heads: int,
|
| 34 |
+
ctx_channels: Optional[int] = None,
|
| 35 |
+
type: Literal["self", "cross"] = "self",
|
| 36 |
+
attn_mode: Literal["full", "serialized", "windowed"] = "full",
|
| 37 |
+
window_size: Optional[int] = None,
|
| 38 |
+
shift_sequence: Optional[int] = None,
|
| 39 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 40 |
+
serialize_mode: Optional[SerializeMode] = None,
|
| 41 |
+
qkv_bias: bool = True,
|
| 42 |
+
use_rope: bool = False,
|
| 43 |
+
qk_rms_norm: bool = False,
|
| 44 |
+
cross_attn_cache_suffix: str = '',
|
| 45 |
+
):
|
| 46 |
+
super().__init__()
|
| 47 |
+
assert channels % num_heads == 0
|
| 48 |
+
assert type in ["self", "cross"], f"Invalid attention type: {type}"
|
| 49 |
+
assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
|
| 50 |
+
assert type == "self" or (attn_mode == "full" or attn_mode == "windowed"), "Cross-attention only supports full and windowed attention"
|
| 51 |
+
assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
|
| 52 |
+
self.channels = channels
|
| 53 |
+
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
| 54 |
+
self.num_heads = num_heads
|
| 55 |
+
self._type = type
|
| 56 |
+
self.attn_mode = attn_mode
|
| 57 |
+
self.window_size = window_size
|
| 58 |
+
self.shift_sequence = shift_sequence
|
| 59 |
+
self.shift_window = shift_window
|
| 60 |
+
self.serialize_mode = serialize_mode
|
| 61 |
+
self.use_rope = use_rope
|
| 62 |
+
self.qk_rms_norm = qk_rms_norm
|
| 63 |
+
|
| 64 |
+
if self._type == "self":
|
| 65 |
+
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
| 66 |
+
else:
|
| 67 |
+
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
| 68 |
+
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
| 69 |
+
self.cross_attn_cache_suffix = cross_attn_cache_suffix
|
| 70 |
+
|
| 71 |
+
if self.qk_rms_norm:
|
| 72 |
+
self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
|
| 73 |
+
self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
|
| 74 |
+
|
| 75 |
+
self.to_out = nn.Linear(channels, channels)
|
| 76 |
+
|
| 77 |
+
if use_rope:
|
| 78 |
+
self.rope = RotaryPositionEmbedder(channels)
|
| 79 |
+
|
| 80 |
+
@staticmethod
|
| 81 |
+
def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
|
| 82 |
+
if isinstance(x, SparseTensor):
|
| 83 |
+
return x.replace(module(x.feats))
|
| 84 |
+
else:
|
| 85 |
+
return module(x)
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
|
| 89 |
+
if isinstance(x, SparseTensor):
|
| 90 |
+
return x.reshape(*shape)
|
| 91 |
+
else:
|
| 92 |
+
return x.reshape(*x.shape[:2], *shape)
|
| 93 |
+
|
| 94 |
+
def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
|
| 95 |
+
if isinstance(x, SparseTensor):
|
| 96 |
+
x_feats = x.feats.unsqueeze(0)
|
| 97 |
+
else:
|
| 98 |
+
x_feats = x
|
| 99 |
+
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
|
| 100 |
+
return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
|
| 101 |
+
|
| 102 |
+
def _rope(self, qkv: SparseTensor) -> SparseTensor:
|
| 103 |
+
q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
|
| 104 |
+
q, k = self.rope(q, k, qkv.coords[:, 1:])
|
| 105 |
+
qkv = qkv.replace(torch.stack([q, k, v], dim=1))
|
| 106 |
+
return qkv
|
| 107 |
+
|
| 108 |
+
def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
|
| 109 |
+
if self._type == "self":
|
| 110 |
+
qkv = self._linear(self.to_qkv, x)
|
| 111 |
+
qkv = self._fused_pre(qkv, num_fused=3)
|
| 112 |
+
if self.use_rope:
|
| 113 |
+
qkv = self._rope(qkv)
|
| 114 |
+
if self.qk_rms_norm:
|
| 115 |
+
q, k, v = qkv.unbind(dim=1)
|
| 116 |
+
q = self.q_rms_norm(q)
|
| 117 |
+
k = self.k_rms_norm(k)
|
| 118 |
+
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
|
| 119 |
+
if self.attn_mode == "full":
|
| 120 |
+
h = sparse_scaled_dot_product_attention(qkv)
|
| 121 |
+
elif self.attn_mode == "serialized":
|
| 122 |
+
h = sparse_serialized_scaled_dot_product_self_attention(
|
| 123 |
+
qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
|
| 124 |
+
)
|
| 125 |
+
elif self.attn_mode == "windowed":
|
| 126 |
+
h = sparse_windowed_scaled_dot_product_self_attention(
|
| 127 |
+
qkv, self.window_size, shift_window=self.shift_window
|
| 128 |
+
)
|
| 129 |
+
else:
|
| 130 |
+
q = self._linear(self.to_q, x)
|
| 131 |
+
q = self._reshape_chs(q, (self.num_heads, -1))
|
| 132 |
+
kv = self._linear(self.to_kv, context)
|
| 133 |
+
kv = self._fused_pre(kv, num_fused=2)
|
| 134 |
+
if self.qk_rms_norm:
|
| 135 |
+
q = self.q_rms_norm(q)
|
| 136 |
+
k, v = kv.unbind(dim=1)
|
| 137 |
+
k = self.k_rms_norm(k)
|
| 138 |
+
kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
|
| 139 |
+
if self.attn_mode == "full":
|
| 140 |
+
h = sparse_scaled_dot_product_attention(q, kv)
|
| 141 |
+
elif self.attn_mode == "windowed":
|
| 142 |
+
q = self._fused_pre(q, num_fused=1)
|
| 143 |
+
h = sparse_windowed_scaled_dot_product_cross_attention(
|
| 144 |
+
q, kv, self.window_size, shift_window=self.shift_window,
|
| 145 |
+
cache_suffix=self.cross_attn_cache_suffix
|
| 146 |
+
)
|
| 147 |
+
elif self.attn_mode == "serialized":
|
| 148 |
+
raise NotImplementedError("Serialized attention is not supported for cross-attention")
|
| 149 |
+
h = self._reshape_chs(h, (-1,))
|
| 150 |
+
h = self._linear(self.to_out, h)
|
| 151 |
+
return h
|
anigen/modules/sparse/attention/serialized_attn.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from enum import Enum
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
from .. import SparseTensor
|
| 6 |
+
from .. import DEBUG, ATTN
|
| 7 |
+
|
| 8 |
+
if ATTN == 'xformers':
|
| 9 |
+
import xformers.ops as xops
|
| 10 |
+
elif ATTN == 'flash_attn':
|
| 11 |
+
import flash_attn
|
| 12 |
+
else:
|
| 13 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
'sparse_serialized_scaled_dot_product_self_attention',
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SerializeMode(Enum):
|
| 22 |
+
Z_ORDER = 0
|
| 23 |
+
Z_ORDER_TRANSPOSED = 1
|
| 24 |
+
HILBERT = 2
|
| 25 |
+
HILBERT_TRANSPOSED = 3
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
SerializeModes = [
|
| 29 |
+
SerializeMode.Z_ORDER,
|
| 30 |
+
SerializeMode.Z_ORDER_TRANSPOSED,
|
| 31 |
+
SerializeMode.HILBERT,
|
| 32 |
+
SerializeMode.HILBERT_TRANSPOSED
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def calc_serialization(
|
| 37 |
+
tensor: SparseTensor,
|
| 38 |
+
window_size: int,
|
| 39 |
+
serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
|
| 40 |
+
shift_sequence: int = 0,
|
| 41 |
+
shift_window: Tuple[int, int, int] = (0, 0, 0)
|
| 42 |
+
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
|
| 43 |
+
"""
|
| 44 |
+
Calculate serialization and partitioning for a set of coordinates.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
tensor (SparseTensor): The input tensor.
|
| 48 |
+
window_size (int): The window size to use.
|
| 49 |
+
serialize_mode (SerializeMode): The serialization mode to use.
|
| 50 |
+
shift_sequence (int): The shift of serialized sequence.
|
| 51 |
+
shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
(torch.Tensor, torch.Tensor): Forwards and backwards indices.
|
| 55 |
+
"""
|
| 56 |
+
fwd_indices = []
|
| 57 |
+
bwd_indices = []
|
| 58 |
+
seq_lens = []
|
| 59 |
+
seq_batch_indices = []
|
| 60 |
+
offsets = [0]
|
| 61 |
+
|
| 62 |
+
if 'vox2seq' not in globals():
|
| 63 |
+
import vox2seq
|
| 64 |
+
|
| 65 |
+
# Serialize the input
|
| 66 |
+
serialize_coords = tensor.coords[:, 1:].clone()
|
| 67 |
+
serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
|
| 68 |
+
if serialize_mode == SerializeMode.Z_ORDER:
|
| 69 |
+
code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
|
| 70 |
+
elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
|
| 71 |
+
code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
|
| 72 |
+
elif serialize_mode == SerializeMode.HILBERT:
|
| 73 |
+
code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
|
| 74 |
+
elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
|
| 75 |
+
code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
|
| 76 |
+
else:
|
| 77 |
+
raise ValueError(f"Unknown serialize mode: {serialize_mode}")
|
| 78 |
+
|
| 79 |
+
for bi, s in enumerate(tensor.layout):
|
| 80 |
+
num_points = s.stop - s.start
|
| 81 |
+
num_windows = (num_points + window_size - 1) // window_size
|
| 82 |
+
valid_window_size = num_points / num_windows
|
| 83 |
+
to_ordered = torch.argsort(code[s.start:s.stop])
|
| 84 |
+
if num_windows == 1:
|
| 85 |
+
fwd_indices.append(to_ordered)
|
| 86 |
+
bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
|
| 87 |
+
fwd_indices[-1] += s.start
|
| 88 |
+
bwd_indices[-1] += offsets[-1]
|
| 89 |
+
seq_lens.append(num_points)
|
| 90 |
+
seq_batch_indices.append(bi)
|
| 91 |
+
offsets.append(offsets[-1] + seq_lens[-1])
|
| 92 |
+
else:
|
| 93 |
+
# Partition the input
|
| 94 |
+
offset = 0
|
| 95 |
+
mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
|
| 96 |
+
split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
|
| 97 |
+
bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
|
| 98 |
+
for i in range(num_windows):
|
| 99 |
+
mid = mids[i]
|
| 100 |
+
valid_start = split[i]
|
| 101 |
+
valid_end = split[i + 1]
|
| 102 |
+
padded_start = math.floor(mid - 0.5 * window_size)
|
| 103 |
+
padded_end = padded_start + window_size
|
| 104 |
+
fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
|
| 105 |
+
offset += valid_start - padded_start
|
| 106 |
+
bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
|
| 107 |
+
offset += padded_end - valid_start
|
| 108 |
+
fwd_indices[-1] += s.start
|
| 109 |
+
seq_lens.extend([window_size] * num_windows)
|
| 110 |
+
seq_batch_indices.extend([bi] * num_windows)
|
| 111 |
+
bwd_indices.append(bwd_index + offsets[-1])
|
| 112 |
+
offsets.append(offsets[-1] + num_windows * window_size)
|
| 113 |
+
|
| 114 |
+
fwd_indices = torch.cat(fwd_indices)
|
| 115 |
+
bwd_indices = torch.cat(bwd_indices)
|
| 116 |
+
|
| 117 |
+
return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def sparse_serialized_scaled_dot_product_self_attention(
|
| 121 |
+
qkv: SparseTensor,
|
| 122 |
+
window_size: int,
|
| 123 |
+
serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
|
| 124 |
+
shift_sequence: int = 0,
|
| 125 |
+
shift_window: Tuple[int, int, int] = (0, 0, 0)
|
| 126 |
+
) -> SparseTensor:
|
| 127 |
+
"""
|
| 128 |
+
Apply serialized scaled dot product self attention to a sparse tensor.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
|
| 132 |
+
window_size (int): The window size to use.
|
| 133 |
+
serialize_mode (SerializeMode): The serialization mode to use.
|
| 134 |
+
shift_sequence (int): The shift of serialized sequence.
|
| 135 |
+
shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
|
| 136 |
+
shift (int): The shift to use.
|
| 137 |
+
"""
|
| 138 |
+
assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
|
| 139 |
+
|
| 140 |
+
serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
|
| 141 |
+
serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
|
| 142 |
+
if serialization_spatial_cache is None:
|
| 143 |
+
fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
|
| 144 |
+
qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
|
| 145 |
+
else:
|
| 146 |
+
fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
|
| 147 |
+
|
| 148 |
+
M = fwd_indices.shape[0]
|
| 149 |
+
T = qkv.feats.shape[0]
|
| 150 |
+
H = qkv.feats.shape[2]
|
| 151 |
+
C = qkv.feats.shape[3]
|
| 152 |
+
|
| 153 |
+
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
| 154 |
+
|
| 155 |
+
if DEBUG:
|
| 156 |
+
start = 0
|
| 157 |
+
qkv_coords = qkv.coords[fwd_indices]
|
| 158 |
+
for i in range(len(seq_lens)):
|
| 159 |
+
assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
|
| 160 |
+
start += seq_lens[i]
|
| 161 |
+
|
| 162 |
+
if all([seq_len == window_size for seq_len in seq_lens]):
|
| 163 |
+
B = len(seq_lens)
|
| 164 |
+
N = window_size
|
| 165 |
+
qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
|
| 166 |
+
if ATTN == 'xformers':
|
| 167 |
+
q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
|
| 168 |
+
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 169 |
+
elif ATTN == 'flash_attn':
|
| 170 |
+
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
|
| 171 |
+
else:
|
| 172 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 173 |
+
out = out.reshape(B * N, H, C) # [M, H, C]
|
| 174 |
+
else:
|
| 175 |
+
if ATTN == 'xformers':
|
| 176 |
+
q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
|
| 177 |
+
q = q.unsqueeze(0) # [1, M, H, C]
|
| 178 |
+
k = k.unsqueeze(0) # [1, M, H, C]
|
| 179 |
+
v = v.unsqueeze(0) # [1, M, H, C]
|
| 180 |
+
mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
|
| 181 |
+
out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
|
| 182 |
+
elif ATTN == 'flash_attn':
|
| 183 |
+
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 184 |
+
.to(qkv.device).int()
|
| 185 |
+
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
|
| 186 |
+
|
| 187 |
+
out = out[bwd_indices] # [T, H, C]
|
| 188 |
+
|
| 189 |
+
if DEBUG:
|
| 190 |
+
qkv_coords = qkv_coords[bwd_indices]
|
| 191 |
+
assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
|
| 192 |
+
|
| 193 |
+
return qkv.replace(out)
|
anigen/modules/sparse/attention/windowed_attn.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
from .. import SparseTensor
|
| 5 |
+
from .. import DEBUG, ATTN
|
| 6 |
+
|
| 7 |
+
if ATTN == 'xformers':
|
| 8 |
+
import xformers.ops as xops
|
| 9 |
+
elif ATTN == 'flash_attn':
|
| 10 |
+
import flash_attn
|
| 11 |
+
else:
|
| 12 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
'sparse_windowed_scaled_dot_product_self_attention',
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def calc_window_partition(
|
| 21 |
+
tensor: SparseTensor,
|
| 22 |
+
window_size: Union[int, Tuple[int, ...]],
|
| 23 |
+
shift_window: Union[int, Tuple[int, ...]] = 0
|
| 24 |
+
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
|
| 25 |
+
"""
|
| 26 |
+
Calculate serialization and partitioning for a set of coordinates.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
tensor (SparseTensor): The input tensor.
|
| 30 |
+
window_size (int): The window size to use.
|
| 31 |
+
shift_window (Tuple[int, ...]): The shift of serialized coordinates.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
(torch.Tensor): Forwards indices.
|
| 35 |
+
(torch.Tensor): Backwards indices.
|
| 36 |
+
(List[int]): Sequence lengths.
|
| 37 |
+
(List[int]): Sequence batch indices.
|
| 38 |
+
"""
|
| 39 |
+
DIM = tensor.coords.shape[1] - 1
|
| 40 |
+
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
|
| 41 |
+
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
|
| 42 |
+
shifted_coords = tensor.coords.clone().detach()
|
| 43 |
+
shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
| 44 |
+
|
| 45 |
+
MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
|
| 46 |
+
NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
|
| 47 |
+
OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
|
| 48 |
+
|
| 49 |
+
shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
| 50 |
+
shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
|
| 51 |
+
fwd_indices = torch.argsort(shifted_indices)
|
| 52 |
+
bwd_indices = torch.empty_like(fwd_indices)
|
| 53 |
+
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
|
| 54 |
+
seq_lens = torch.bincount(shifted_indices)
|
| 55 |
+
seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
|
| 56 |
+
mask = seq_lens != 0
|
| 57 |
+
seq_lens = seq_lens[mask].tolist()
|
| 58 |
+
seq_batch_indices = seq_batch_indices[mask].tolist()
|
| 59 |
+
|
| 60 |
+
return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def sparse_windowed_scaled_dot_product_self_attention(
|
| 64 |
+
qkv: SparseTensor,
|
| 65 |
+
window_size: int,
|
| 66 |
+
shift_window: Tuple[int, int, int] = (0, 0, 0)
|
| 67 |
+
) -> SparseTensor:
|
| 68 |
+
"""
|
| 69 |
+
Apply windowed scaled dot product self attention to a sparse tensor.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
|
| 73 |
+
window_size (int): The window size to use.
|
| 74 |
+
shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
|
| 75 |
+
shift (int): The shift to use.
|
| 76 |
+
"""
|
| 77 |
+
assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
|
| 78 |
+
|
| 79 |
+
serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
|
| 80 |
+
serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
|
| 81 |
+
if serialization_spatial_cache is None:
|
| 82 |
+
fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
|
| 83 |
+
qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
|
| 84 |
+
else:
|
| 85 |
+
fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache
|
| 86 |
+
|
| 87 |
+
M = fwd_indices.shape[0]
|
| 88 |
+
T = qkv.feats.shape[0]
|
| 89 |
+
H = qkv.feats.shape[2]
|
| 90 |
+
C = qkv.feats.shape[3]
|
| 91 |
+
|
| 92 |
+
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
| 93 |
+
|
| 94 |
+
if DEBUG:
|
| 95 |
+
start = 0
|
| 96 |
+
qkv_coords = qkv.coords[fwd_indices]
|
| 97 |
+
for i in range(len(seq_lens)):
|
| 98 |
+
seq_coords = qkv_coords[start:start+seq_lens[i]]
|
| 99 |
+
assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
|
| 100 |
+
assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
|
| 101 |
+
f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
|
| 102 |
+
start += seq_lens[i]
|
| 103 |
+
|
| 104 |
+
if all([seq_len == window_size for seq_len in seq_lens]):
|
| 105 |
+
B = len(seq_lens)
|
| 106 |
+
N = window_size
|
| 107 |
+
qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
|
| 108 |
+
if ATTN == 'xformers':
|
| 109 |
+
q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C]
|
| 110 |
+
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 111 |
+
elif ATTN == 'flash_attn':
|
| 112 |
+
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
|
| 113 |
+
else:
|
| 114 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 115 |
+
out = out.reshape(B * N, H, C) # [M, H, C]
|
| 116 |
+
else:
|
| 117 |
+
if ATTN == 'xformers':
|
| 118 |
+
q, k, v = qkv_feats.unbind(dim=1) # [M, H, C]
|
| 119 |
+
q = q.unsqueeze(0) # [1, M, H, C]
|
| 120 |
+
k = k.unsqueeze(0) # [1, M, H, C]
|
| 121 |
+
v = v.unsqueeze(0) # [1, M, H, C]
|
| 122 |
+
mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
|
| 123 |
+
out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
|
| 124 |
+
elif ATTN == 'flash_attn':
|
| 125 |
+
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 126 |
+
.to(qkv.device).int()
|
| 127 |
+
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
|
| 128 |
+
|
| 129 |
+
out = out[bwd_indices] # [T, H, C]
|
| 130 |
+
|
| 131 |
+
if DEBUG:
|
| 132 |
+
qkv_coords = qkv_coords[bwd_indices]
|
| 133 |
+
assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"
|
| 134 |
+
|
| 135 |
+
return qkv.replace(out)
|
anigen/modules/sparse/attention/windowed_attn_cross.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
from .. import SparseTensor
|
| 5 |
+
from .. import DEBUG, ATTN
|
| 6 |
+
|
| 7 |
+
if ATTN == 'xformers':
|
| 8 |
+
import xformers.ops as xops
|
| 9 |
+
elif ATTN == 'flash_attn':
|
| 10 |
+
import flash_attn
|
| 11 |
+
else:
|
| 12 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
'sparse_windowed_scaled_dot_product_cross_attention',
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def calc_window_partition_cross(
|
| 21 |
+
tensor: SparseTensor,
|
| 22 |
+
context: SparseTensor,
|
| 23 |
+
window_size: Union[int, Tuple[int, ...]],
|
| 24 |
+
shift_window: Union[int, Tuple[int, ...]] = 0
|
| 25 |
+
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
|
| 26 |
+
"""
|
| 27 |
+
Calculate serialization and partitioning for a set of coordinates.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
tensor (SparseTensor): The input tensor.
|
| 31 |
+
window_size (int): The window size to use.
|
| 32 |
+
shift_window (Tuple[int, ...]): The shift of serialized coordinates.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
(torch.Tensor): Forwards indices.
|
| 36 |
+
(torch.Tensor): Backwards indices.
|
| 37 |
+
(List[int]): Sequence lengths.
|
| 38 |
+
(List[int]): Sequence batch indices.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def calc_window_partition_(tensor, window_size, shift_window):
|
| 42 |
+
DIM = tensor.coords.shape[1] - 1
|
| 43 |
+
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
|
| 44 |
+
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
|
| 45 |
+
shifted_coords = tensor.coords.clone().detach()
|
| 46 |
+
shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
| 47 |
+
|
| 48 |
+
MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
|
| 49 |
+
NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
|
| 50 |
+
OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
|
| 51 |
+
|
| 52 |
+
shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
| 53 |
+
shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
|
| 54 |
+
fwd_indices = torch.argsort(shifted_indices)
|
| 55 |
+
bwd_indices = torch.empty_like(fwd_indices)
|
| 56 |
+
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
|
| 57 |
+
seq_lens = torch.bincount(shifted_indices)
|
| 58 |
+
seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]
|
| 59 |
+
return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
|
| 60 |
+
|
| 61 |
+
fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition_(tensor, window_size, shift_window)
|
| 62 |
+
fwd_indices_context, bwd_indices_context, seq_lens_context, seq_batch_indices_context = calc_window_partition_(context, window_size, shift_window)
|
| 63 |
+
# Pad the shorter one to the shape of the other with 0 tail
|
| 64 |
+
max_len = max(seq_lens.shape[0], seq_lens_context.shape[0])
|
| 65 |
+
if seq_lens.shape[0] < max_len:
|
| 66 |
+
pad_size = max_len - seq_lens.shape[0]
|
| 67 |
+
seq_lens = torch.cat([seq_lens, torch.zeros(pad_size, dtype=seq_lens.dtype, device=seq_lens.device)])
|
| 68 |
+
if seq_lens_context.shape[0] < max_len:
|
| 69 |
+
pad_size = max_len - seq_lens_context.shape[0]
|
| 70 |
+
seq_lens_context = torch.cat([seq_lens_context, torch.zeros(pad_size, dtype=seq_lens_context.dtype, device=seq_lens_context.device)])
|
| 71 |
+
mask = (seq_lens != 0) | (seq_lens_context != 0)
|
| 72 |
+
seq_lens = seq_lens[mask].tolist()
|
| 73 |
+
seq_lens_context = seq_lens_context[mask].tolist()
|
| 74 |
+
|
| 75 |
+
return fwd_indices, bwd_indices, seq_lens, fwd_indices_context, bwd_indices_context, seq_lens_context
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def sparse_windowed_scaled_dot_product_cross_attention(
|
| 79 |
+
q: SparseTensor,
|
| 80 |
+
kv: SparseTensor,
|
| 81 |
+
window_size: int,
|
| 82 |
+
shift_window: Tuple[int, int, int] = (0, 0, 0),
|
| 83 |
+
cache_suffix: str = '',
|
| 84 |
+
) -> SparseTensor:
|
| 85 |
+
"""
|
| 86 |
+
Apply windowed scaled dot product cross attention to a sparse tensor.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
q, kv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
|
| 90 |
+
window_size (int): The window size to use.
|
| 91 |
+
shift_window (Tuple[int, int, int]): The shift of serialized coordinates.
|
| 92 |
+
shift (int): The shift to use.
|
| 93 |
+
"""
|
| 94 |
+
assert len(q.shape) == 4 and q.shape[1] == 1, f"Invalid shape for q, got {q.shape}, expected [N, *, 1, H, C]"
|
| 95 |
+
assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
|
| 96 |
+
|
| 97 |
+
serialization_spatial_cache_name_q = f'window_partition_{window_size}_{shift_window}_cross_q' + cache_suffix
|
| 98 |
+
serialization_spatial_cache_q = q.get_spatial_cache(serialization_spatial_cache_name_q)
|
| 99 |
+
serialization_spatial_cache_name_kv = f'window_partition_{window_size}_{shift_window}_cross_kv' + cache_suffix
|
| 100 |
+
serialization_spatial_cache_kv = kv.get_spatial_cache(serialization_spatial_cache_name_kv)
|
| 101 |
+
if serialization_spatial_cache_q is None or serialization_spatial_cache_kv is None:
|
| 102 |
+
q_fwd_indices, q_bwd_indices, q_seq_lens, kv_fwd_indices, kv_bwd_indices, kv_seq_lens = calc_window_partition_cross(q, kv, window_size, shift_window)
|
| 103 |
+
q.register_spatial_cache(serialization_spatial_cache_name_q, (q_fwd_indices, q_bwd_indices, q_seq_lens))
|
| 104 |
+
kv.register_spatial_cache(serialization_spatial_cache_name_kv, (kv_fwd_indices, kv_bwd_indices, kv_seq_lens))
|
| 105 |
+
else:
|
| 106 |
+
kv_fwd_indices, kv_bwd_indices, kv_seq_lens = serialization_spatial_cache_kv
|
| 107 |
+
q_fwd_indices, q_bwd_indices, q_seq_lens = serialization_spatial_cache_q
|
| 108 |
+
|
| 109 |
+
M_q, T_q, H_q, C_q = q_fwd_indices.shape[0], q.feats.shape[0], q.feats.shape[2], q.feats.shape[3]
|
| 110 |
+
M_kv, T_kv, H_kv, C_kv = kv_fwd_indices.shape[0], kv.feats.shape[0], kv.feats.shape[2], kv.feats.shape[3]
|
| 111 |
+
assert (H_q == H_kv and C_q == C_kv), \
|
| 112 |
+
f"Mismatch in shapes: q ({M_q}, {T_q}, {H_q}, {C_q}), kv ({M_kv}, {T_kv}, {H_kv}, {C_kv})"
|
| 113 |
+
|
| 114 |
+
q_feats = q.feats[q_fwd_indices] # [M, 1, H, C]
|
| 115 |
+
kv_feats = kv.feats[kv_fwd_indices] # [M, 2, H, C]
|
| 116 |
+
|
| 117 |
+
if ATTN == 'xformers':
|
| 118 |
+
q, k, v = q_feats[:, 0], kv_feats.unbind(dim=1) # [M, H, C]
|
| 119 |
+
q = q.unsqueeze(0) # [1, M, H, C]
|
| 120 |
+
k = k.unsqueeze(0) # [1, M, H, C]
|
| 121 |
+
v = v.unsqueeze(0) # [1, M, H, C]
|
| 122 |
+
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seq_lens, kv_seq_lens)
|
| 123 |
+
out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
|
| 124 |
+
elif ATTN == 'flash_attn':
|
| 125 |
+
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seq_lens), dim=0)], dim=0).to(q.device).int()
|
| 126 |
+
cu_seqlens_k = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seq_lens), dim=0)], dim=0).to(kv.device).int()
|
| 127 |
+
out = flash_attn.flash_attn_varlen_kvpacked_func(q_feats[:, 0], kv_feats, cu_seqlens_q, cu_seqlens_k, max(q_seq_lens), max(kv_seq_lens)) # [M, H, C]
|
| 128 |
+
|
| 129 |
+
out = out[q_bwd_indices] # [T, H, C]
|
| 130 |
+
return q.replace(out)
|
| 131 |
+
|
anigen/modules/sparse/basic.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from . import BACKEND, DEBUG
|
| 5 |
+
SparseTensorData = None # Lazy import
|
| 6 |
+
|
| 7 |
+
import importlib
|
| 8 |
+
if BACKEND == 'torchsparse':
|
| 9 |
+
SparseTensorData = importlib.import_module('torchsparse').SparseTensor
|
| 10 |
+
elif BACKEND == 'spconv':
|
| 11 |
+
SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
'SparseTensor',
|
| 16 |
+
'sparse_batch_broadcast',
|
| 17 |
+
'sparse_batch_op',
|
| 18 |
+
'sparse_cat',
|
| 19 |
+
'sparse_unbind',
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SparseTensor:
|
| 24 |
+
"""
|
| 25 |
+
Sparse tensor with support for both torchsparse and spconv backends.
|
| 26 |
+
|
| 27 |
+
Parameters:
|
| 28 |
+
- feats (torch.Tensor): Features of the sparse tensor.
|
| 29 |
+
- coords (torch.Tensor): Coordinates of the sparse tensor.
|
| 30 |
+
- shape (torch.Size): Shape of the sparse tensor.
|
| 31 |
+
- layout (List[slice]): Layout of the sparse tensor for each batch
|
| 32 |
+
- data (SparseTensorData): Sparse tensor data used for convolusion
|
| 33 |
+
|
| 34 |
+
NOTE:
|
| 35 |
+
- Data corresponding to a same batch should be contiguous.
|
| 36 |
+
- Coords should be in [0, 1023]
|
| 37 |
+
"""
|
| 38 |
+
@overload
|
| 39 |
+
def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
|
| 40 |
+
|
| 41 |
+
@overload
|
| 42 |
+
def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
|
| 43 |
+
|
| 44 |
+
def __init__(self, *args, **kwargs):
|
| 45 |
+
# Lazy import of sparse tensor backend
|
| 46 |
+
global SparseTensorData
|
| 47 |
+
if SparseTensorData is None:
|
| 48 |
+
import importlib
|
| 49 |
+
if BACKEND == 'torchsparse':
|
| 50 |
+
SparseTensorData = importlib.import_module('torchsparse').SparseTensor
|
| 51 |
+
elif BACKEND == 'spconv':
|
| 52 |
+
SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
|
| 53 |
+
|
| 54 |
+
method_id = 0
|
| 55 |
+
if len(args) != 0:
|
| 56 |
+
method_id = 0 if isinstance(args[0], torch.Tensor) else 1
|
| 57 |
+
else:
|
| 58 |
+
method_id = 1 if 'data' in kwargs else 0
|
| 59 |
+
|
| 60 |
+
if method_id == 0:
|
| 61 |
+
feats, coords, shape, layout = args + (None,) * (4 - len(args))
|
| 62 |
+
if 'feats' in kwargs:
|
| 63 |
+
feats = kwargs['feats']
|
| 64 |
+
del kwargs['feats']
|
| 65 |
+
if 'coords' in kwargs:
|
| 66 |
+
coords = kwargs['coords']
|
| 67 |
+
del kwargs['coords']
|
| 68 |
+
if 'shape' in kwargs:
|
| 69 |
+
shape = kwargs['shape']
|
| 70 |
+
del kwargs['shape']
|
| 71 |
+
if 'layout' in kwargs:
|
| 72 |
+
layout = kwargs['layout']
|
| 73 |
+
del kwargs['layout']
|
| 74 |
+
|
| 75 |
+
if shape is None:
|
| 76 |
+
shape = self.__cal_shape(feats, coords)
|
| 77 |
+
if layout is None:
|
| 78 |
+
layout = self.__cal_layout(coords, shape[0])
|
| 79 |
+
if BACKEND == 'torchsparse':
|
| 80 |
+
self.data = SparseTensorData(feats, coords, **kwargs)
|
| 81 |
+
elif BACKEND == 'spconv':
|
| 82 |
+
spatial_shape = list(coords.max(0)[0] + 1)[1:]
|
| 83 |
+
self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
|
| 84 |
+
self.data._features = feats
|
| 85 |
+
elif method_id == 1:
|
| 86 |
+
data, shape, layout = args + (None,) * (3 - len(args))
|
| 87 |
+
if 'data' in kwargs:
|
| 88 |
+
data = kwargs['data']
|
| 89 |
+
del kwargs['data']
|
| 90 |
+
if 'shape' in kwargs:
|
| 91 |
+
shape = kwargs['shape']
|
| 92 |
+
del kwargs['shape']
|
| 93 |
+
if 'layout' in kwargs:
|
| 94 |
+
layout = kwargs['layout']
|
| 95 |
+
del kwargs['layout']
|
| 96 |
+
|
| 97 |
+
self.data = data
|
| 98 |
+
if shape is None:
|
| 99 |
+
shape = self.__cal_shape(self.feats, self.coords)
|
| 100 |
+
if layout is None:
|
| 101 |
+
layout = self.__cal_layout(self.coords, shape[0])
|
| 102 |
+
|
| 103 |
+
self._shape = shape
|
| 104 |
+
self._layout = layout
|
| 105 |
+
self._scale = kwargs.get('scale', (1, 1, 1))
|
| 106 |
+
self._spatial_cache = kwargs.get('spatial_cache', {})
|
| 107 |
+
|
| 108 |
+
if DEBUG:
|
| 109 |
+
try:
|
| 110 |
+
assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
|
| 111 |
+
assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
|
| 112 |
+
assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
|
| 113 |
+
for i in range(self.shape[0]):
|
| 114 |
+
assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
|
| 115 |
+
except Exception as e:
|
| 116 |
+
print('Debugging information:')
|
| 117 |
+
print(f"- Shape: {self.shape}")
|
| 118 |
+
print(f"- Layout: {self.layout}")
|
| 119 |
+
print(f"- Scale: {self._scale}")
|
| 120 |
+
print(f"- Coords: {self.coords}")
|
| 121 |
+
raise e
|
| 122 |
+
|
| 123 |
+
def __cal_shape(self, feats, coords):
|
| 124 |
+
shape = []
|
| 125 |
+
shape.append(coords[:, 0].max().item() + 1)
|
| 126 |
+
shape.extend([*feats.shape[1:]])
|
| 127 |
+
return torch.Size(shape)
|
| 128 |
+
|
| 129 |
+
def __cal_layout(self, coords, batch_size):
|
| 130 |
+
seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
|
| 131 |
+
offset = torch.cumsum(seq_len, dim=0)
|
| 132 |
+
layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
|
| 133 |
+
return layout
|
| 134 |
+
|
| 135 |
+
@property
|
| 136 |
+
def shape(self) -> torch.Size:
|
| 137 |
+
return self._shape
|
| 138 |
+
|
| 139 |
+
def dim(self) -> int:
|
| 140 |
+
return len(self.shape)
|
| 141 |
+
|
| 142 |
+
@property
|
| 143 |
+
def layout(self) -> List[slice]:
|
| 144 |
+
return self._layout
|
| 145 |
+
|
| 146 |
+
@property
|
| 147 |
+
def feats(self) -> torch.Tensor:
|
| 148 |
+
if BACKEND == 'torchsparse':
|
| 149 |
+
return self.data.F
|
| 150 |
+
elif BACKEND == 'spconv':
|
| 151 |
+
return self.data.features
|
| 152 |
+
|
| 153 |
+
@feats.setter
|
| 154 |
+
def feats(self, value: torch.Tensor):
|
| 155 |
+
if BACKEND == 'torchsparse':
|
| 156 |
+
self.data.F = value
|
| 157 |
+
elif BACKEND == 'spconv':
|
| 158 |
+
self.data.features = value
|
| 159 |
+
|
| 160 |
+
@property
|
| 161 |
+
def coords(self) -> torch.Tensor:
|
| 162 |
+
if BACKEND == 'torchsparse':
|
| 163 |
+
return self.data.C
|
| 164 |
+
elif BACKEND == 'spconv':
|
| 165 |
+
return self.data.indices
|
| 166 |
+
|
| 167 |
+
@coords.setter
|
| 168 |
+
def coords(self, value: torch.Tensor):
|
| 169 |
+
if BACKEND == 'torchsparse':
|
| 170 |
+
self.data.C = value
|
| 171 |
+
elif BACKEND == 'spconv':
|
| 172 |
+
self.data.indices = value
|
| 173 |
+
|
| 174 |
+
@property
|
| 175 |
+
def dtype(self):
|
| 176 |
+
return self.feats.dtype
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def device(self):
|
| 180 |
+
return self.feats.device
|
| 181 |
+
|
| 182 |
+
@overload
|
| 183 |
+
def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
|
| 184 |
+
|
| 185 |
+
@overload
|
| 186 |
+
def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
|
| 187 |
+
|
| 188 |
+
def to(self, *args, **kwargs) -> 'SparseTensor':
|
| 189 |
+
device = None
|
| 190 |
+
dtype = None
|
| 191 |
+
if len(args) == 2:
|
| 192 |
+
device, dtype = args
|
| 193 |
+
elif len(args) == 1:
|
| 194 |
+
if isinstance(args[0], torch.dtype):
|
| 195 |
+
dtype = args[0]
|
| 196 |
+
else:
|
| 197 |
+
device = args[0]
|
| 198 |
+
if 'dtype' in kwargs:
|
| 199 |
+
assert dtype is None, "to() received multiple values for argument 'dtype'"
|
| 200 |
+
dtype = kwargs['dtype']
|
| 201 |
+
if 'device' in kwargs:
|
| 202 |
+
assert device is None, "to() received multiple values for argument 'device'"
|
| 203 |
+
device = kwargs['device']
|
| 204 |
+
|
| 205 |
+
new_feats = self.feats.to(device=device, dtype=dtype)
|
| 206 |
+
new_coords = self.coords.to(device=device)
|
| 207 |
+
return self.replace(new_feats, new_coords)
|
| 208 |
+
|
| 209 |
+
def type(self, dtype):
|
| 210 |
+
new_feats = self.feats.type(dtype)
|
| 211 |
+
return self.replace(new_feats)
|
| 212 |
+
|
| 213 |
+
def cpu(self) -> 'SparseTensor':
|
| 214 |
+
new_feats = self.feats.cpu()
|
| 215 |
+
new_coords = self.coords.cpu()
|
| 216 |
+
return self.replace(new_feats, new_coords)
|
| 217 |
+
|
| 218 |
+
def cuda(self) -> 'SparseTensor':
|
| 219 |
+
new_feats = self.feats.cuda()
|
| 220 |
+
new_coords = self.coords.cuda()
|
| 221 |
+
return self.replace(new_feats, new_coords)
|
| 222 |
+
|
| 223 |
+
def half(self) -> 'SparseTensor':
|
| 224 |
+
new_feats = self.feats.half()
|
| 225 |
+
return self.replace(new_feats)
|
| 226 |
+
|
| 227 |
+
def float(self) -> 'SparseTensor':
|
| 228 |
+
new_feats = self.feats.float()
|
| 229 |
+
return self.replace(new_feats)
|
| 230 |
+
|
| 231 |
+
def detach(self) -> 'SparseTensor':
|
| 232 |
+
new_coords = self.coords.detach()
|
| 233 |
+
new_feats = self.feats.detach()
|
| 234 |
+
return self.replace(new_feats, new_coords)
|
| 235 |
+
|
| 236 |
+
def dense(self) -> torch.Tensor:
|
| 237 |
+
if BACKEND == 'torchsparse':
|
| 238 |
+
return self.data.dense()
|
| 239 |
+
elif BACKEND == 'spconv':
|
| 240 |
+
return self.data.dense()
|
| 241 |
+
|
| 242 |
+
def reshape(self, *shape) -> 'SparseTensor':
|
| 243 |
+
new_feats = self.feats.reshape(self.feats.shape[0], *shape)
|
| 244 |
+
return self.replace(new_feats)
|
| 245 |
+
|
| 246 |
+
def unbind(self, dim: int) -> List['SparseTensor']:
|
| 247 |
+
return sparse_unbind(self, dim)
|
| 248 |
+
|
| 249 |
+
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
|
| 250 |
+
new_shape = [self.shape[0]]
|
| 251 |
+
new_shape.extend(feats.shape[1:])
|
| 252 |
+
if BACKEND == 'torchsparse':
|
| 253 |
+
new_data = SparseTensorData(
|
| 254 |
+
feats=feats,
|
| 255 |
+
coords=self.data.coords if coords is None else coords,
|
| 256 |
+
stride=self.data.stride,
|
| 257 |
+
spatial_range=self.data.spatial_range,
|
| 258 |
+
)
|
| 259 |
+
new_data._caches = self.data._caches
|
| 260 |
+
elif BACKEND == 'spconv':
|
| 261 |
+
new_data = SparseTensorData(
|
| 262 |
+
self.data.features.reshape(self.data.features.shape[0], -1),
|
| 263 |
+
self.data.indices,
|
| 264 |
+
self.data.spatial_shape,
|
| 265 |
+
self.data.batch_size,
|
| 266 |
+
self.data.grid,
|
| 267 |
+
self.data.voxel_num,
|
| 268 |
+
self.data.indice_dict
|
| 269 |
+
)
|
| 270 |
+
new_data._features = feats
|
| 271 |
+
new_data.benchmark = self.data.benchmark
|
| 272 |
+
new_data.benchmark_record = self.data.benchmark_record
|
| 273 |
+
new_data.thrust_allocator = self.data.thrust_allocator
|
| 274 |
+
new_data._timer = self.data._timer
|
| 275 |
+
new_data.force_algo = self.data.force_algo
|
| 276 |
+
new_data.int8_scale = self.data.int8_scale
|
| 277 |
+
if coords is not None:
|
| 278 |
+
new_data.indices = coords
|
| 279 |
+
new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
|
| 280 |
+
return new_tensor
|
| 281 |
+
|
| 282 |
+
@staticmethod
|
| 283 |
+
def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
|
| 284 |
+
N, C = dim
|
| 285 |
+
x = torch.arange(aabb[0], aabb[3] + 1)
|
| 286 |
+
y = torch.arange(aabb[1], aabb[4] + 1)
|
| 287 |
+
z = torch.arange(aabb[2], aabb[5] + 1)
|
| 288 |
+
coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
|
| 289 |
+
coords = torch.cat([
|
| 290 |
+
torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
|
| 291 |
+
coords.repeat(N, 1),
|
| 292 |
+
], dim=1).to(dtype=torch.int32, device=device)
|
| 293 |
+
feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
|
| 294 |
+
return SparseTensor(feats=feats, coords=coords)
|
| 295 |
+
|
| 296 |
+
def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
|
| 297 |
+
new_cache = {}
|
| 298 |
+
for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
|
| 299 |
+
if k in self._spatial_cache:
|
| 300 |
+
new_cache[k] = self._spatial_cache[k]
|
| 301 |
+
if k in other._spatial_cache:
|
| 302 |
+
if k not in new_cache:
|
| 303 |
+
new_cache[k] = other._spatial_cache[k]
|
| 304 |
+
else:
|
| 305 |
+
new_cache[k].update(other._spatial_cache[k])
|
| 306 |
+
return new_cache
|
| 307 |
+
|
| 308 |
+
def __neg__(self) -> 'SparseTensor':
|
| 309 |
+
return self.replace(-self.feats)
|
| 310 |
+
|
| 311 |
+
def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
|
| 312 |
+
if isinstance(other, torch.Tensor):
|
| 313 |
+
try:
|
| 314 |
+
other = torch.broadcast_to(other, self.shape)
|
| 315 |
+
other = sparse_batch_broadcast(self, other)
|
| 316 |
+
except:
|
| 317 |
+
pass
|
| 318 |
+
if isinstance(other, SparseTensor):
|
| 319 |
+
other = other.feats
|
| 320 |
+
new_feats = op(self.feats, other)
|
| 321 |
+
new_tensor = self.replace(new_feats)
|
| 322 |
+
if isinstance(other, SparseTensor):
|
| 323 |
+
new_tensor._spatial_cache = self.__merge_sparse_cache(other)
|
| 324 |
+
return new_tensor
|
| 325 |
+
|
| 326 |
+
def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
| 327 |
+
return self.__elemwise__(other, torch.add)
|
| 328 |
+
|
| 329 |
+
def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
| 330 |
+
return self.__elemwise__(other, torch.add)
|
| 331 |
+
|
| 332 |
+
def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
| 333 |
+
return self.__elemwise__(other, torch.sub)
|
| 334 |
+
|
| 335 |
+
def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
| 336 |
+
return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
|
| 337 |
+
|
| 338 |
+
def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
| 339 |
+
return self.__elemwise__(other, torch.mul)
|
| 340 |
+
|
| 341 |
+
def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
| 342 |
+
return self.__elemwise__(other, torch.mul)
|
| 343 |
+
|
| 344 |
+
def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
| 345 |
+
return self.__elemwise__(other, torch.div)
|
| 346 |
+
|
| 347 |
+
def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
|
| 348 |
+
return self.__elemwise__(other, lambda x, y: torch.div(y, x))
|
| 349 |
+
|
| 350 |
+
def __getitem__(self, idx):
|
| 351 |
+
if isinstance(idx, int):
|
| 352 |
+
idx = [idx]
|
| 353 |
+
elif isinstance(idx, slice):
|
| 354 |
+
idx = range(*idx.indices(self.shape[0]))
|
| 355 |
+
elif isinstance(idx, torch.Tensor):
|
| 356 |
+
if idx.dtype == torch.bool:
|
| 357 |
+
assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
|
| 358 |
+
idx = idx.nonzero().squeeze(1)
|
| 359 |
+
elif idx.dtype in [torch.int32, torch.int64]:
|
| 360 |
+
assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
|
| 361 |
+
else:
|
| 362 |
+
raise ValueError(f"Unknown index type: {idx.dtype}")
|
| 363 |
+
else:
|
| 364 |
+
raise ValueError(f"Unknown index type: {type(idx)}")
|
| 365 |
+
|
| 366 |
+
coords = []
|
| 367 |
+
feats = []
|
| 368 |
+
for new_idx, old_idx in enumerate(idx):
|
| 369 |
+
coords.append(self.coords[self.layout[old_idx]].clone())
|
| 370 |
+
coords[-1][:, 0] = new_idx
|
| 371 |
+
feats.append(self.feats[self.layout[old_idx]])
|
| 372 |
+
coords = torch.cat(coords, dim=0).contiguous()
|
| 373 |
+
feats = torch.cat(feats, dim=0).contiguous()
|
| 374 |
+
return SparseTensor(feats=feats, coords=coords)
|
| 375 |
+
|
| 376 |
+
def register_spatial_cache(self, key, value) -> None:
|
| 377 |
+
"""
|
| 378 |
+
Register a spatial cache.
|
| 379 |
+
The spatial cache can be any thing you want to cache.
|
| 380 |
+
The registery and retrieval of the cache is based on current scale.
|
| 381 |
+
"""
|
| 382 |
+
scale_key = str(self._scale)
|
| 383 |
+
if scale_key not in self._spatial_cache:
|
| 384 |
+
self._spatial_cache[scale_key] = {}
|
| 385 |
+
self._spatial_cache[scale_key][key] = value
|
| 386 |
+
|
| 387 |
+
def get_spatial_cache(self, key=None):
|
| 388 |
+
"""
|
| 389 |
+
Get a spatial cache.
|
| 390 |
+
"""
|
| 391 |
+
scale_key = str(self._scale)
|
| 392 |
+
cur_scale_cache = self._spatial_cache.get(scale_key, {})
|
| 393 |
+
if key is None:
|
| 394 |
+
return cur_scale_cache
|
| 395 |
+
return cur_scale_cache.get(key, None)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
|
| 399 |
+
"""
|
| 400 |
+
Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
input (torch.Tensor): 1D tensor to broadcast.
|
| 404 |
+
target (SparseTensor): Sparse tensor to broadcast to.
|
| 405 |
+
op (callable): Operation to perform after broadcasting. Defaults to torch.add.
|
| 406 |
+
"""
|
| 407 |
+
coords, feats = input.coords, input.feats
|
| 408 |
+
broadcasted = torch.zeros_like(feats)
|
| 409 |
+
for k in range(input.shape[0]):
|
| 410 |
+
broadcasted[input.layout[k]] = other[k]
|
| 411 |
+
return broadcasted
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
|
| 415 |
+
"""
|
| 416 |
+
Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
|
| 417 |
+
|
| 418 |
+
Args:
|
| 419 |
+
input (torch.Tensor): 1D tensor to broadcast.
|
| 420 |
+
target (SparseTensor): Sparse tensor to broadcast to.
|
| 421 |
+
op (callable): Operation to perform after broadcasting. Defaults to torch.add.
|
| 422 |
+
"""
|
| 423 |
+
return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
|
| 427 |
+
"""
|
| 428 |
+
Concatenate a list of sparse tensors.
|
| 429 |
+
|
| 430 |
+
Args:
|
| 431 |
+
inputs (List[SparseTensor]): List of sparse tensors to concatenate.
|
| 432 |
+
"""
|
| 433 |
+
if dim == 0:
|
| 434 |
+
start = 0
|
| 435 |
+
coords = []
|
| 436 |
+
for input in inputs:
|
| 437 |
+
coords.append(input.coords.clone())
|
| 438 |
+
coords[-1][:, 0] += start
|
| 439 |
+
start += input.shape[0]
|
| 440 |
+
coords = torch.cat(coords, dim=0)
|
| 441 |
+
feats = torch.cat([input.feats for input in inputs], dim=0)
|
| 442 |
+
output = SparseTensor(
|
| 443 |
+
coords=coords,
|
| 444 |
+
feats=feats,
|
| 445 |
+
)
|
| 446 |
+
else:
|
| 447 |
+
feats = torch.cat([input.feats for input in inputs], dim=dim)
|
| 448 |
+
output = inputs[0].replace(feats)
|
| 449 |
+
|
| 450 |
+
return output
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
|
| 454 |
+
"""
|
| 455 |
+
Unbind a sparse tensor along a dimension.
|
| 456 |
+
|
| 457 |
+
Args:
|
| 458 |
+
input (SparseTensor): Sparse tensor to unbind.
|
| 459 |
+
dim (int): Dimension to unbind.
|
| 460 |
+
"""
|
| 461 |
+
if dim == 0:
|
| 462 |
+
return [input[i] for i in range(input.shape[0])]
|
| 463 |
+
else:
|
| 464 |
+
feats = input.feats.unbind(dim)
|
| 465 |
+
return [input.replace(f) for f in feats]
|
anigen/modules/sparse/conv/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .. import BACKEND
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
|
| 5 |
+
|
| 6 |
+
def __from_env():
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
global SPCONV_ALGO
|
| 10 |
+
env_spconv_algo = os.environ.get('SPCONV_ALGO')
|
| 11 |
+
if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
|
| 12 |
+
SPCONV_ALGO = env_spconv_algo
|
| 13 |
+
print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
__from_env()
|
| 17 |
+
|
| 18 |
+
if BACKEND == 'torchsparse':
|
| 19 |
+
from .conv_torchsparse import *
|
| 20 |
+
elif BACKEND == 'spconv':
|
| 21 |
+
from .conv_spconv import *
|
anigen/modules/sparse/conv/conv_spconv.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .. import SparseTensor
|
| 4 |
+
from .. import DEBUG
|
| 5 |
+
from . import SPCONV_ALGO
|
| 6 |
+
|
| 7 |
+
class SparseConv3d(nn.Module):
|
| 8 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
|
| 9 |
+
super(SparseConv3d, self).__init__()
|
| 10 |
+
if 'spconv' not in globals():
|
| 11 |
+
import spconv.pytorch as spconv
|
| 12 |
+
algo = None
|
| 13 |
+
if SPCONV_ALGO == 'native':
|
| 14 |
+
algo = spconv.ConvAlgo.Native
|
| 15 |
+
elif SPCONV_ALGO == 'implicit_gemm':
|
| 16 |
+
algo = spconv.ConvAlgo.MaskImplicitGemm
|
| 17 |
+
if stride == 1 and (padding is None):
|
| 18 |
+
self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
|
| 19 |
+
else:
|
| 20 |
+
self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
|
| 21 |
+
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
|
| 22 |
+
self.padding = padding
|
| 23 |
+
|
| 24 |
+
def forward(self, x: SparseTensor) -> SparseTensor:
|
| 25 |
+
spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
|
| 26 |
+
new_data = self.conv(x.data)
|
| 27 |
+
new_shape = [x.shape[0], self.conv.out_channels]
|
| 28 |
+
new_layout = None if spatial_changed else x.layout
|
| 29 |
+
|
| 30 |
+
if spatial_changed and (x.shape[0] != 1):
|
| 31 |
+
# spconv was non-1 stride will break the contiguous of the output tensor, sort by the coords
|
| 32 |
+
fwd = new_data.indices[:, 0].argsort()
|
| 33 |
+
bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
|
| 34 |
+
sorted_feats = new_data.features[fwd]
|
| 35 |
+
sorted_coords = new_data.indices[fwd]
|
| 36 |
+
unsorted_data = new_data
|
| 37 |
+
new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size) # type: ignore
|
| 38 |
+
|
| 39 |
+
out = SparseTensor(
|
| 40 |
+
new_data, shape=torch.Size(new_shape), layout=new_layout,
|
| 41 |
+
scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
|
| 42 |
+
spatial_cache=x._spatial_cache,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if spatial_changed and (x.shape[0] != 1):
|
| 46 |
+
out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
|
| 47 |
+
out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
|
| 48 |
+
|
| 49 |
+
return out
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SparseInverseConv3d(nn.Module):
|
| 53 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
|
| 54 |
+
super(SparseInverseConv3d, self).__init__()
|
| 55 |
+
if 'spconv' not in globals():
|
| 56 |
+
import spconv.pytorch as spconv
|
| 57 |
+
self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
|
| 58 |
+
self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
|
| 59 |
+
|
| 60 |
+
def forward(self, x: SparseTensor) -> SparseTensor:
|
| 61 |
+
spatial_changed = any(s != 1 for s in self.stride)
|
| 62 |
+
if spatial_changed:
|
| 63 |
+
# recover the original spconv order
|
| 64 |
+
data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
|
| 65 |
+
bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
|
| 66 |
+
data = data.replace_feature(x.feats[bwd])
|
| 67 |
+
if DEBUG:
|
| 68 |
+
assert torch.equal(data.indices, x.coords[bwd]), 'Recover the original order failed'
|
| 69 |
+
else:
|
| 70 |
+
data = x.data
|
| 71 |
+
|
| 72 |
+
new_data = self.conv(data)
|
| 73 |
+
new_shape = [x.shape[0], self.conv.out_channels]
|
| 74 |
+
new_layout = None if spatial_changed else x.layout
|
| 75 |
+
out = SparseTensor(
|
| 76 |
+
new_data, shape=torch.Size(new_shape), layout=new_layout,
|
| 77 |
+
scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
|
| 78 |
+
spatial_cache=x._spatial_cache,
|
| 79 |
+
)
|
| 80 |
+
return out
|
anigen/modules/sparse/conv/conv_torchsparse.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .. import SparseTensor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SparseConv3d(nn.Module):
|
| 7 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
|
| 8 |
+
super(SparseConv3d, self).__init__()
|
| 9 |
+
if 'torchsparse' not in globals():
|
| 10 |
+
import torchsparse
|
| 11 |
+
self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
|
| 12 |
+
|
| 13 |
+
def forward(self, x: SparseTensor) -> SparseTensor:
|
| 14 |
+
out = self.conv(x.data)
|
| 15 |
+
new_shape = [x.shape[0], self.conv.out_channels]
|
| 16 |
+
out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
|
| 17 |
+
out._spatial_cache = x._spatial_cache
|
| 18 |
+
out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
|
| 19 |
+
return out
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SparseInverseConv3d(nn.Module):
|
| 23 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
|
| 24 |
+
super(SparseInverseConv3d, self).__init__()
|
| 25 |
+
if 'torchsparse' not in globals():
|
| 26 |
+
import torchsparse
|
| 27 |
+
self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
|
| 28 |
+
|
| 29 |
+
def forward(self, x: SparseTensor) -> SparseTensor:
|
| 30 |
+
out = self.conv(x.data)
|
| 31 |
+
new_shape = [x.shape[0], self.conv.out_channels]
|
| 32 |
+
out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
|
| 33 |
+
out._spatial_cache = x._spatial_cache
|
| 34 |
+
out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
|
| 35 |
+
return out
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
anigen/modules/sparse/linear.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from . import SparseTensor
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
'SparseLinear'
|
| 7 |
+
]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SparseLinear(nn.Linear):
|
| 11 |
+
def __init__(self, in_features, out_features, bias=True):
|
| 12 |
+
super(SparseLinear, self).__init__(in_features, out_features, bias)
|
| 13 |
+
|
| 14 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 15 |
+
return input.replace(super().forward(input.feats))
|
anigen/modules/sparse/nonlinearity.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from . import SparseTensor
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
'SparseReLU',
|
| 7 |
+
'SparseSiLU',
|
| 8 |
+
'SparseGELU',
|
| 9 |
+
'SparseActivation'
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SparseReLU(nn.ReLU):
|
| 14 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 15 |
+
return input.replace(super().forward(input.feats))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SparseSiLU(nn.SiLU):
|
| 19 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 20 |
+
return input.replace(super().forward(input.feats))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SparseGELU(nn.GELU):
|
| 24 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 25 |
+
return input.replace(super().forward(input.feats))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SparseActivation(nn.Module):
|
| 29 |
+
def __init__(self, activation: nn.Module):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.activation = activation
|
| 32 |
+
|
| 33 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 34 |
+
return input.replace(self.activation(input.feats))
|
| 35 |
+
|
anigen/modules/sparse/norm.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from . import SparseTensor
|
| 4 |
+
from . import DEBUG
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
'SparseGroupNorm',
|
| 8 |
+
'SparseLayerNorm',
|
| 9 |
+
'SparseGroupNorm32',
|
| 10 |
+
'SparseLayerNorm32',
|
| 11 |
+
]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SparseGroupNorm(nn.GroupNorm):
|
| 15 |
+
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
|
| 16 |
+
super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
|
| 17 |
+
|
| 18 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 19 |
+
nfeats = torch.zeros_like(input.feats)
|
| 20 |
+
for k in range(input.shape[0]):
|
| 21 |
+
if DEBUG:
|
| 22 |
+
assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
|
| 23 |
+
bfeats = input.feats[input.layout[k]]
|
| 24 |
+
bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
|
| 25 |
+
bfeats = super().forward(bfeats)
|
| 26 |
+
bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
|
| 27 |
+
nfeats[input.layout[k]] = bfeats
|
| 28 |
+
return input.replace(nfeats)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SparseLayerNorm(nn.LayerNorm):
|
| 32 |
+
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
|
| 33 |
+
super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
|
| 34 |
+
|
| 35 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 36 |
+
nfeats = torch.zeros_like(input.feats)
|
| 37 |
+
for k in range(input.shape[0]):
|
| 38 |
+
bfeats = input.feats[input.layout[k]]
|
| 39 |
+
bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
|
| 40 |
+
bfeats = super().forward(bfeats)
|
| 41 |
+
bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
|
| 42 |
+
nfeats[input.layout[k]] = bfeats
|
| 43 |
+
return input.replace(nfeats)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SparseGroupNorm32(SparseGroupNorm):
|
| 47 |
+
"""
|
| 48 |
+
A GroupNorm layer that converts to float32 before the forward pass.
|
| 49 |
+
"""
|
| 50 |
+
def forward(self, x: SparseTensor) -> SparseTensor:
|
| 51 |
+
return super().forward(x.float()).type(x.dtype)
|
| 52 |
+
|
| 53 |
+
class SparseLayerNorm32(SparseLayerNorm):
|
| 54 |
+
"""
|
| 55 |
+
A LayerNorm layer that converts to float32 before the forward pass.
|
| 56 |
+
"""
|
| 57 |
+
def forward(self, x: SparseTensor) -> SparseTensor:
|
| 58 |
+
return super().forward(x.float()).type(x.dtype)
|
anigen/modules/sparse/spatial.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from . import SparseTensor
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
'SparseDownsample',
|
| 8 |
+
'SparseUpsample',
|
| 9 |
+
'SparseSubdivide'
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SparseDownsample(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Downsample a sparse tensor by a factor of `factor`.
|
| 16 |
+
Implemented as average pooling.
|
| 17 |
+
"""
|
| 18 |
+
def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
|
| 19 |
+
super(SparseDownsample, self).__init__()
|
| 20 |
+
self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
|
| 21 |
+
|
| 22 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 23 |
+
DIM = input.coords.shape[-1] - 1
|
| 24 |
+
factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
|
| 25 |
+
assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
|
| 26 |
+
|
| 27 |
+
coord = list(input.coords.unbind(dim=-1))
|
| 28 |
+
for i, f in enumerate(factor):
|
| 29 |
+
coord[i+1] = coord[i+1] // f
|
| 30 |
+
|
| 31 |
+
MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
|
| 32 |
+
OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
|
| 33 |
+
code = sum([c * o for c, o in zip(coord, OFFSET)])
|
| 34 |
+
code, idx = code.unique(return_inverse=True)
|
| 35 |
+
|
| 36 |
+
new_feats = torch.scatter_reduce(
|
| 37 |
+
torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
|
| 38 |
+
dim=0,
|
| 39 |
+
index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
|
| 40 |
+
src=input.feats,
|
| 41 |
+
reduce='mean'
|
| 42 |
+
)
|
| 43 |
+
new_coords = torch.stack(
|
| 44 |
+
[code // OFFSET[0]] +
|
| 45 |
+
[(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
|
| 46 |
+
dim=-1
|
| 47 |
+
)
|
| 48 |
+
out = SparseTensor(new_feats, new_coords, input.shape,)
|
| 49 |
+
out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
|
| 50 |
+
out._spatial_cache = input._spatial_cache
|
| 51 |
+
|
| 52 |
+
out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
|
| 53 |
+
out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
|
| 54 |
+
out.register_spatial_cache(f'upsample_{factor}_idx', idx)
|
| 55 |
+
|
| 56 |
+
return out
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class SparseUpsample(nn.Module):
|
| 60 |
+
"""
|
| 61 |
+
Upsample a sparse tensor by a factor of `factor`.
|
| 62 |
+
Implemented as nearest neighbor interpolation.
|
| 63 |
+
"""
|
| 64 |
+
def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
|
| 65 |
+
super(SparseUpsample, self).__init__()
|
| 66 |
+
self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
|
| 67 |
+
|
| 68 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 69 |
+
DIM = input.coords.shape[-1] - 1
|
| 70 |
+
factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
|
| 71 |
+
assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
|
| 72 |
+
|
| 73 |
+
new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
|
| 74 |
+
new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
|
| 75 |
+
idx = input.get_spatial_cache(f'upsample_{factor}_idx')
|
| 76 |
+
if any([x is None for x in [new_coords, new_layout, idx]]):
|
| 77 |
+
raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
|
| 78 |
+
new_feats = input.feats[idx]
|
| 79 |
+
out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
|
| 80 |
+
out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
|
| 81 |
+
out._spatial_cache = input._spatial_cache
|
| 82 |
+
return out
|
| 83 |
+
|
| 84 |
+
class SparseSubdivide(nn.Module):
|
| 85 |
+
"""
|
| 86 |
+
Upsample a sparse tensor by a factor of `factor`.
|
| 87 |
+
Implemented as nearest neighbor interpolation.
|
| 88 |
+
"""
|
| 89 |
+
def __init__(self):
|
| 90 |
+
super(SparseSubdivide, self).__init__()
|
| 91 |
+
|
| 92 |
+
def forward(self, input: SparseTensor) -> SparseTensor:
|
| 93 |
+
DIM = input.coords.shape[-1] - 1
|
| 94 |
+
# upsample scale=2^DIM
|
| 95 |
+
n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
|
| 96 |
+
n_coords = torch.nonzero(n_cube)
|
| 97 |
+
n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
|
| 98 |
+
factor = n_coords.shape[0]
|
| 99 |
+
assert factor == 2 ** DIM
|
| 100 |
+
# print(n_coords.shape)
|
| 101 |
+
new_coords = input.coords.clone()
|
| 102 |
+
new_coords[:, 1:] *= 2
|
| 103 |
+
new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
|
| 104 |
+
|
| 105 |
+
new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
|
| 106 |
+
out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
|
| 107 |
+
out._scale = input._scale * 2
|
| 108 |
+
out._spatial_cache = input._spatial_cache
|
| 109 |
+
return out
|
| 110 |
+
|
anigen/modules/sparse/transformer/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .blocks import *
|
| 2 |
+
from .modulated import *
|
| 3 |
+
from .anigen_modulated import *
|
anigen/modules/sparse/transformer/anigen_modulated.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from ..basic import SparseTensor
|
| 5 |
+
from ..attention import SparseMultiHeadAttention, SerializeMode
|
| 6 |
+
from ...norm import LayerNorm32
|
| 7 |
+
from .blocks import SparseFeedForwardNet
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AniGenModulatedSparseTransformerCrossBlock(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
AniGen Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
channels: int,
|
| 17 |
+
channels_skl: int,
|
| 18 |
+
ctx_channels: int,
|
| 19 |
+
num_heads: int,
|
| 20 |
+
num_heads_skl: int,
|
| 21 |
+
mlp_ratio: float = 4.0,
|
| 22 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
| 23 |
+
window_size: Optional[int] = None,
|
| 24 |
+
shift_sequence: Optional[int] = None,
|
| 25 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 26 |
+
serialize_mode: Optional[SerializeMode] = None,
|
| 27 |
+
use_checkpoint: bool = False,
|
| 28 |
+
use_rope: bool = False,
|
| 29 |
+
qk_rms_norm: bool = False,
|
| 30 |
+
qk_rms_norm_cross: bool = False,
|
| 31 |
+
qkv_bias: bool = True,
|
| 32 |
+
share_mod: bool = False,
|
| 33 |
+
|
| 34 |
+
):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.use_checkpoint = use_checkpoint
|
| 37 |
+
self.share_mod = share_mod
|
| 38 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 39 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
| 40 |
+
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 41 |
+
self.norm1_skl = LayerNorm32(channels_skl, elementwise_affine=False, eps=1e-6)
|
| 42 |
+
self.norm2_skl = LayerNorm32(channels_skl, elementwise_affine=True, eps=1e-6)
|
| 43 |
+
self.norm3_skl = LayerNorm32(channels_skl, elementwise_affine=False, eps=1e-6)
|
| 44 |
+
self.attn = SparseMultiHeadAttention(
|
| 45 |
+
channels,
|
| 46 |
+
ctx_channels=channels_skl,
|
| 47 |
+
num_heads=num_heads,
|
| 48 |
+
type="cross",
|
| 49 |
+
attn_mode=attn_mode,
|
| 50 |
+
window_size=window_size,
|
| 51 |
+
shift_sequence=shift_sequence,
|
| 52 |
+
shift_window=shift_window,
|
| 53 |
+
serialize_mode=serialize_mode,
|
| 54 |
+
qkv_bias=qkv_bias,
|
| 55 |
+
use_rope=use_rope,
|
| 56 |
+
qk_rms_norm=qk_rms_norm,
|
| 57 |
+
)
|
| 58 |
+
self.attn_skl = SparseMultiHeadAttention(
|
| 59 |
+
channels_skl,
|
| 60 |
+
ctx_channels=channels,
|
| 61 |
+
num_heads=num_heads_skl,
|
| 62 |
+
type="cross",
|
| 63 |
+
attn_mode=attn_mode,
|
| 64 |
+
window_size=window_size,
|
| 65 |
+
shift_sequence=shift_sequence,
|
| 66 |
+
shift_window=shift_window,
|
| 67 |
+
serialize_mode=serialize_mode,
|
| 68 |
+
qkv_bias=qkv_bias,
|
| 69 |
+
use_rope=use_rope,
|
| 70 |
+
qk_rms_norm=qk_rms_norm,
|
| 71 |
+
)
|
| 72 |
+
self.context_cross_attn = SparseMultiHeadAttention(
|
| 73 |
+
channels,
|
| 74 |
+
ctx_channels=ctx_channels,
|
| 75 |
+
num_heads=num_heads,
|
| 76 |
+
type="cross",
|
| 77 |
+
attn_mode="full",
|
| 78 |
+
qkv_bias=qkv_bias,
|
| 79 |
+
qk_rms_norm=qk_rms_norm_cross,
|
| 80 |
+
)
|
| 81 |
+
self.context_cross_attn_skl = SparseMultiHeadAttention(
|
| 82 |
+
channels_skl,
|
| 83 |
+
ctx_channels=ctx_channels,
|
| 84 |
+
num_heads=num_heads_skl,
|
| 85 |
+
type="cross",
|
| 86 |
+
attn_mode="full",
|
| 87 |
+
qkv_bias=qkv_bias,
|
| 88 |
+
qk_rms_norm=qk_rms_norm_cross,
|
| 89 |
+
)
|
| 90 |
+
self.mlp = SparseFeedForwardNet(
|
| 91 |
+
channels,
|
| 92 |
+
mlp_ratio=mlp_ratio,
|
| 93 |
+
)
|
| 94 |
+
self.mlp_skl = SparseFeedForwardNet(
|
| 95 |
+
channels_skl,
|
| 96 |
+
mlp_ratio=mlp_ratio,
|
| 97 |
+
)
|
| 98 |
+
if not share_mod:
|
| 99 |
+
self.adaLN_modulation = nn.Sequential(
|
| 100 |
+
nn.SiLU(),
|
| 101 |
+
nn.Linear(channels, 6 * channels, bias=True)
|
| 102 |
+
)
|
| 103 |
+
self.adaLN_modulation_skl = nn.Sequential(
|
| 104 |
+
nn.SiLU(),
|
| 105 |
+
nn.Linear(channels_skl, 6 * channels_skl, bias=True)
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
def _forward(self, x: SparseTensor, x_skl: SparseTensor, mod: torch.Tensor, mod_skl: torch.Tensor, context: torch.Tensor) -> SparseTensor:
|
| 109 |
+
if self.share_mod:
|
| 110 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
| 111 |
+
shift_msa_skl, scale_msa_skl, gate_msa_skl, shift_mlp_skl, scale_mlp_skl, gate_mlp_skl = mod_skl.chunk(6, dim=1)
|
| 112 |
+
else:
|
| 113 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
| 114 |
+
shift_msa_skl, scale_msa_skl, gate_msa_skl, shift_mlp_skl, scale_mlp_skl, gate_mlp_skl = self.adaLN_modulation_skl(mod_skl).chunk(6, dim=1)
|
| 115 |
+
# Input Norm
|
| 116 |
+
h = x.replace(self.norm1(x.feats))
|
| 117 |
+
h_skl = x_skl.replace(self.norm1_skl(x_skl.feats))
|
| 118 |
+
# AdaLN (By Time Step)
|
| 119 |
+
h = h * (1 + scale_msa) + shift_msa
|
| 120 |
+
h_skl = h_skl * (1 + scale_msa_skl) + shift_msa_skl
|
| 121 |
+
# Self Attn (Cross shape and skeleton)
|
| 122 |
+
h = self.attn(h, h_skl)
|
| 123 |
+
h_skl = self.attn_skl(h_skl, h)
|
| 124 |
+
# Gated Residual (By Time Step)
|
| 125 |
+
h = h * gate_msa
|
| 126 |
+
h_skl = h_skl * gate_msa_skl
|
| 127 |
+
x = x + h
|
| 128 |
+
x_skl = x_skl + h_skl
|
| 129 |
+
# Context Cross Attention
|
| 130 |
+
h = x.replace(self.norm2(x.feats))
|
| 131 |
+
h_skl = x_skl.replace(self.norm2_skl(x_skl.feats))
|
| 132 |
+
h = self.context_cross_attn(h, context)
|
| 133 |
+
h_skl = self.context_cross_attn_skl(h_skl, context)
|
| 134 |
+
x = x + h
|
| 135 |
+
x_skl = x_skl + h_skl
|
| 136 |
+
# Re-Centered
|
| 137 |
+
h = x.replace(self.norm3(x.feats))
|
| 138 |
+
h_skl = x_skl.replace(self.norm3_skl(x_skl.feats))
|
| 139 |
+
h = h * (1 + scale_mlp) + shift_mlp
|
| 140 |
+
h_skl = h_skl * (1 + scale_mlp_skl) + shift_mlp_skl
|
| 141 |
+
# Output MLP
|
| 142 |
+
h = self.mlp(h)
|
| 143 |
+
h_skl = self.mlp_skl(h_skl)
|
| 144 |
+
# Gated Residual (By Time Step)
|
| 145 |
+
h = h * gate_mlp
|
| 146 |
+
h_skl = h_skl * gate_mlp_skl
|
| 147 |
+
x = x + h
|
| 148 |
+
x_skl = x_skl + h_skl
|
| 149 |
+
return x, x_skl
|
| 150 |
+
|
| 151 |
+
def forward(self, x: SparseTensor, x_skl: SparseTensor, mod: torch.Tensor, mod_skl: torch.Tensor, context: torch.Tensor) -> SparseTensor:
|
| 152 |
+
if self.use_checkpoint:
|
| 153 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, x_skl, mod, mod_skl, context, use_reentrant=False)
|
| 154 |
+
else:
|
| 155 |
+
return self._forward(x, x_skl, mod, mod_skl, context)
|
anigen/modules/sparse/transformer/blocks.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from ..basic import SparseTensor
|
| 5 |
+
from ..linear import SparseLinear
|
| 6 |
+
from ..nonlinearity import SparseGELU
|
| 7 |
+
from ..attention import SparseMultiHeadAttention, SerializeMode
|
| 8 |
+
from ...norm import LayerNorm32
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class SparseFeedForwardNet(nn.Module):
|
| 12 |
+
def __init__(self, channels: int, mlp_ratio: float = 4.0):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.mlp = nn.Sequential(
|
| 15 |
+
SparseLinear(channels, int(channels * mlp_ratio)),
|
| 16 |
+
SparseGELU(approximate="tanh"),
|
| 17 |
+
SparseLinear(int(channels * mlp_ratio), channels),
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
def forward(self, x: SparseTensor) -> SparseTensor:
|
| 21 |
+
return self.mlp(x)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SparseTransformerBlock(nn.Module):
|
| 25 |
+
"""
|
| 26 |
+
Sparse Transformer block (MSA + FFN).
|
| 27 |
+
"""
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
channels: int,
|
| 31 |
+
num_heads: int,
|
| 32 |
+
mlp_ratio: float = 4.0,
|
| 33 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
| 34 |
+
window_size: Optional[int] = None,
|
| 35 |
+
shift_sequence: Optional[int] = None,
|
| 36 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 37 |
+
serialize_mode: Optional[SerializeMode] = None,
|
| 38 |
+
use_checkpoint: bool = False,
|
| 39 |
+
use_rope: bool = False,
|
| 40 |
+
qk_rms_norm: bool = False,
|
| 41 |
+
qkv_bias: bool = True,
|
| 42 |
+
ln_affine: bool = False,
|
| 43 |
+
):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.use_checkpoint = use_checkpoint
|
| 46 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 47 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 48 |
+
self.attn = SparseMultiHeadAttention(
|
| 49 |
+
channels,
|
| 50 |
+
num_heads=num_heads,
|
| 51 |
+
attn_mode=attn_mode,
|
| 52 |
+
window_size=window_size,
|
| 53 |
+
shift_sequence=shift_sequence,
|
| 54 |
+
shift_window=shift_window,
|
| 55 |
+
serialize_mode=serialize_mode,
|
| 56 |
+
qkv_bias=qkv_bias,
|
| 57 |
+
use_rope=use_rope,
|
| 58 |
+
qk_rms_norm=qk_rms_norm,
|
| 59 |
+
)
|
| 60 |
+
self.mlp = SparseFeedForwardNet(
|
| 61 |
+
channels,
|
| 62 |
+
mlp_ratio=mlp_ratio,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def _forward(self, x: SparseTensor) -> SparseTensor:
|
| 66 |
+
# Self-attention
|
| 67 |
+
h = x.replace(self.norm1(x.feats))
|
| 68 |
+
h = self.attn(h)
|
| 69 |
+
x = x + h
|
| 70 |
+
# Feed-forward network
|
| 71 |
+
h = x.replace(self.norm2(x.feats))
|
| 72 |
+
h = self.mlp(h)
|
| 73 |
+
x = x + h
|
| 74 |
+
return x
|
| 75 |
+
|
| 76 |
+
def forward(self, x: SparseTensor) -> SparseTensor:
|
| 77 |
+
if self.use_checkpoint:
|
| 78 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
|
| 79 |
+
else:
|
| 80 |
+
return self._forward(x)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SparseTransformerCrossBlock(nn.Module):
|
| 84 |
+
"""
|
| 85 |
+
Sparse Transformer cross-attention block (MSA + MCA + FFN).
|
| 86 |
+
"""
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
channels: int,
|
| 90 |
+
ctx_channels: int,
|
| 91 |
+
num_heads: int,
|
| 92 |
+
mlp_ratio: float = 4.0,
|
| 93 |
+
attn_mode: Literal["full", "serialized", "windowed"] = "full",
|
| 94 |
+
attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
|
| 95 |
+
window_size: Optional[int] = None,
|
| 96 |
+
shift_sequence: Optional[int] = None,
|
| 97 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 98 |
+
serialize_mode: Optional[SerializeMode] = None,
|
| 99 |
+
use_checkpoint: bool = False,
|
| 100 |
+
use_rope: bool = False,
|
| 101 |
+
qk_rms_norm: bool = False,
|
| 102 |
+
qk_rms_norm_cross: bool = False,
|
| 103 |
+
qkv_bias: bool = True,
|
| 104 |
+
ln_affine: bool = False,
|
| 105 |
+
context_is_dual: bool = False,
|
| 106 |
+
):
|
| 107 |
+
super().__init__()
|
| 108 |
+
self.use_checkpoint = use_checkpoint
|
| 109 |
+
self.context_is_dual = context_is_dual
|
| 110 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 111 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 112 |
+
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 113 |
+
if context_is_dual:
|
| 114 |
+
self.norm4 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 115 |
+
self.self_attn = SparseMultiHeadAttention(
|
| 116 |
+
channels,
|
| 117 |
+
num_heads=num_heads,
|
| 118 |
+
type="self",
|
| 119 |
+
attn_mode=attn_mode,
|
| 120 |
+
window_size=window_size,
|
| 121 |
+
shift_sequence=shift_sequence,
|
| 122 |
+
shift_window=shift_window,
|
| 123 |
+
serialize_mode=serialize_mode,
|
| 124 |
+
qkv_bias=qkv_bias,
|
| 125 |
+
use_rope=use_rope,
|
| 126 |
+
qk_rms_norm=qk_rms_norm,
|
| 127 |
+
)
|
| 128 |
+
self.cross_attn = SparseMultiHeadAttention(
|
| 129 |
+
channels,
|
| 130 |
+
ctx_channels=ctx_channels,
|
| 131 |
+
num_heads=num_heads,
|
| 132 |
+
type="cross",
|
| 133 |
+
attn_mode=attn_mode_cross,
|
| 134 |
+
window_size=window_size,
|
| 135 |
+
shift_sequence=shift_sequence,
|
| 136 |
+
shift_window=shift_window,
|
| 137 |
+
serialize_mode=serialize_mode,
|
| 138 |
+
qkv_bias=qkv_bias,
|
| 139 |
+
qk_rms_norm=qk_rms_norm_cross,
|
| 140 |
+
)
|
| 141 |
+
self.mlp = SparseFeedForwardNet(
|
| 142 |
+
channels,
|
| 143 |
+
mlp_ratio=mlp_ratio,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
def _forward(self, x: SparseTensor, context: SparseTensor):
|
| 147 |
+
# Self-attention
|
| 148 |
+
h = x.replace(self.norm1(x.feats))
|
| 149 |
+
h = self.self_attn(h)
|
| 150 |
+
x = x + h
|
| 151 |
+
# Cross-attention
|
| 152 |
+
h = x.replace(self.norm2(x.feats))
|
| 153 |
+
if self.context_is_dual:
|
| 154 |
+
context = context.replace(self.norm4(context.feats))
|
| 155 |
+
h = self.cross_attn(h, context)
|
| 156 |
+
x = x + h
|
| 157 |
+
# Feed-forward network
|
| 158 |
+
h = x.replace(self.norm3(x.feats))
|
| 159 |
+
h = self.mlp(h)
|
| 160 |
+
x = x + h
|
| 161 |
+
return x
|
| 162 |
+
|
| 163 |
+
def forward(self, x: SparseTensor, context: SparseTensor):
|
| 164 |
+
if self.use_checkpoint:
|
| 165 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
|
| 166 |
+
else:
|
| 167 |
+
return self._forward(x, context)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class SparseTransformerMultiContextCrossBlock(nn.Module):
|
| 171 |
+
"""
|
| 172 |
+
Sparse Transformer cross-attention block (MSA + MCA + FFN).
|
| 173 |
+
"""
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
channels: int,
|
| 177 |
+
ctx_channels: List[int],
|
| 178 |
+
num_heads: int,
|
| 179 |
+
mlp_ratio: float = 4.0,
|
| 180 |
+
attn_mode: Literal["full", "serialized", "windowed"] = "full",
|
| 181 |
+
attn_mode_cross: Literal["full", "serialized", "windowed"] = "full",
|
| 182 |
+
window_size: Optional[int] = None,
|
| 183 |
+
shift_sequence: Optional[int] = None,
|
| 184 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 185 |
+
serialize_mode: Optional[SerializeMode] = None,
|
| 186 |
+
use_checkpoint: bool = False,
|
| 187 |
+
use_rope: bool = False,
|
| 188 |
+
qk_rms_norm: bool = False,
|
| 189 |
+
qk_rms_norm_cross: bool = False,
|
| 190 |
+
qkv_bias: bool = True,
|
| 191 |
+
ln_affine: bool = False,
|
| 192 |
+
cross_attn_cache_suffix: str = '',
|
| 193 |
+
):
|
| 194 |
+
super().__init__()
|
| 195 |
+
self.context_num = len(ctx_channels)
|
| 196 |
+
self.use_checkpoint = use_checkpoint
|
| 197 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 198 |
+
if self.context_num > 0:
|
| 199 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
| 200 |
+
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 201 |
+
|
| 202 |
+
self.self_attn = SparseMultiHeadAttention(
|
| 203 |
+
channels,
|
| 204 |
+
num_heads=num_heads,
|
| 205 |
+
type="self",
|
| 206 |
+
attn_mode=attn_mode,
|
| 207 |
+
window_size=window_size,
|
| 208 |
+
shift_sequence=shift_sequence,
|
| 209 |
+
shift_window=shift_window,
|
| 210 |
+
serialize_mode=serialize_mode,
|
| 211 |
+
qkv_bias=qkv_bias,
|
| 212 |
+
use_rope=use_rope,
|
| 213 |
+
qk_rms_norm=qk_rms_norm,
|
| 214 |
+
)
|
| 215 |
+
for i in range(self.context_num):
|
| 216 |
+
setattr(self, f'cross_attn_{i}', SparseMultiHeadAttention(
|
| 217 |
+
channels,
|
| 218 |
+
ctx_channels=ctx_channels[i],
|
| 219 |
+
num_heads=num_heads,
|
| 220 |
+
type="cross",
|
| 221 |
+
attn_mode=attn_mode_cross,
|
| 222 |
+
window_size=window_size,
|
| 223 |
+
shift_sequence=shift_sequence,
|
| 224 |
+
shift_window=shift_window,
|
| 225 |
+
serialize_mode=serialize_mode,
|
| 226 |
+
qkv_bias=qkv_bias,
|
| 227 |
+
qk_rms_norm=qk_rms_norm_cross,
|
| 228 |
+
cross_attn_cache_suffix=cross_attn_cache_suffix + f'_modality_{i}'
|
| 229 |
+
))
|
| 230 |
+
setattr(self, f'ctx_norm_{i}', LayerNorm32(ctx_channels[i], elementwise_affine=True, eps=1e-6))
|
| 231 |
+
|
| 232 |
+
self.mlp = SparseFeedForwardNet(
|
| 233 |
+
channels,
|
| 234 |
+
mlp_ratio=mlp_ratio,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
def _forward(self, x: SparseTensor, contexts: List[SparseTensor]) -> SparseTensor:
|
| 238 |
+
# Self-attention
|
| 239 |
+
h = x.replace(self.norm1(x.feats))
|
| 240 |
+
h = self.self_attn(h)
|
| 241 |
+
x = x + h
|
| 242 |
+
# Cross-attention
|
| 243 |
+
if self.context_num > 0 and len(contexts) > 0:
|
| 244 |
+
h_norm = x.replace(self.norm2(x.feats))
|
| 245 |
+
for i, context in enumerate(contexts):
|
| 246 |
+
context = context.replace(getattr(self, f'ctx_norm_{i}')(context.feats))
|
| 247 |
+
h = getattr(self, f'cross_attn_{i}')(h_norm, context)
|
| 248 |
+
x = x + h
|
| 249 |
+
# Feed-forward network
|
| 250 |
+
h = x.replace(self.norm3(x.feats))
|
| 251 |
+
h = self.mlp(h)
|
| 252 |
+
x = x + h
|
| 253 |
+
return x
|
| 254 |
+
|
| 255 |
+
def forward(self, x: SparseTensor, contexts: List[SparseTensor]) -> SparseTensor:
|
| 256 |
+
if self.use_checkpoint:
|
| 257 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, contexts, use_reentrant=False)
|
| 258 |
+
else:
|
| 259 |
+
return self._forward(x, contexts)
|
anigen/modules/sparse/transformer/modulated.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from ..basic import SparseTensor
|
| 5 |
+
from ..attention import SparseMultiHeadAttention, SerializeMode
|
| 6 |
+
from ...norm import LayerNorm32
|
| 7 |
+
from .blocks import SparseFeedForwardNet
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ModulatedSparseTransformerBlock(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
channels: int,
|
| 17 |
+
num_heads: int,
|
| 18 |
+
mlp_ratio: float = 4.0,
|
| 19 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
| 20 |
+
window_size: Optional[int] = None,
|
| 21 |
+
shift_sequence: Optional[int] = None,
|
| 22 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 23 |
+
serialize_mode: Optional[SerializeMode] = None,
|
| 24 |
+
use_checkpoint: bool = False,
|
| 25 |
+
use_rope: bool = False,
|
| 26 |
+
qk_rms_norm: bool = False,
|
| 27 |
+
qkv_bias: bool = True,
|
| 28 |
+
share_mod: bool = False,
|
| 29 |
+
):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.use_checkpoint = use_checkpoint
|
| 32 |
+
self.share_mod = share_mod
|
| 33 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 34 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 35 |
+
self.attn = SparseMultiHeadAttention(
|
| 36 |
+
channels,
|
| 37 |
+
num_heads=num_heads,
|
| 38 |
+
attn_mode=attn_mode,
|
| 39 |
+
window_size=window_size,
|
| 40 |
+
shift_sequence=shift_sequence,
|
| 41 |
+
shift_window=shift_window,
|
| 42 |
+
serialize_mode=serialize_mode,
|
| 43 |
+
qkv_bias=qkv_bias,
|
| 44 |
+
use_rope=use_rope,
|
| 45 |
+
qk_rms_norm=qk_rms_norm,
|
| 46 |
+
)
|
| 47 |
+
self.mlp = SparseFeedForwardNet(
|
| 48 |
+
channels,
|
| 49 |
+
mlp_ratio=mlp_ratio,
|
| 50 |
+
)
|
| 51 |
+
if not share_mod:
|
| 52 |
+
self.adaLN_modulation = nn.Sequential(
|
| 53 |
+
nn.SiLU(),
|
| 54 |
+
nn.Linear(channels, 6 * channels, bias=True)
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
|
| 58 |
+
if self.share_mod:
|
| 59 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
| 60 |
+
else:
|
| 61 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
| 62 |
+
h = x.replace(self.norm1(x.feats))
|
| 63 |
+
h = h * (1 + scale_msa) + shift_msa
|
| 64 |
+
h = self.attn(h)
|
| 65 |
+
h = h * gate_msa
|
| 66 |
+
x = x + h
|
| 67 |
+
h = x.replace(self.norm2(x.feats))
|
| 68 |
+
h = h * (1 + scale_mlp) + shift_mlp
|
| 69 |
+
h = self.mlp(h)
|
| 70 |
+
h = h * gate_mlp
|
| 71 |
+
x = x + h
|
| 72 |
+
return x
|
| 73 |
+
|
| 74 |
+
def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
|
| 75 |
+
if self.use_checkpoint:
|
| 76 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
|
| 77 |
+
else:
|
| 78 |
+
return self._forward(x, mod)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ModulatedSparseTransformerCrossBlock(nn.Module):
|
| 82 |
+
"""
|
| 83 |
+
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
| 84 |
+
"""
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
channels: int,
|
| 88 |
+
ctx_channels: int,
|
| 89 |
+
num_heads: int,
|
| 90 |
+
mlp_ratio: float = 4.0,
|
| 91 |
+
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
| 92 |
+
window_size: Optional[int] = None,
|
| 93 |
+
shift_sequence: Optional[int] = None,
|
| 94 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 95 |
+
serialize_mode: Optional[SerializeMode] = None,
|
| 96 |
+
use_checkpoint: bool = False,
|
| 97 |
+
use_rope: bool = False,
|
| 98 |
+
qk_rms_norm: bool = False,
|
| 99 |
+
qk_rms_norm_cross: bool = False,
|
| 100 |
+
qkv_bias: bool = True,
|
| 101 |
+
share_mod: bool = False,
|
| 102 |
+
norm_for_context: bool = False,
|
| 103 |
+
):
|
| 104 |
+
super().__init__()
|
| 105 |
+
self.use_checkpoint = use_checkpoint
|
| 106 |
+
self.share_mod = share_mod
|
| 107 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 108 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
| 109 |
+
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 110 |
+
self.norm_for_context = norm_for_context
|
| 111 |
+
if self.norm_for_context:
|
| 112 |
+
self.context_norm = LayerNorm32(ctx_channels, elementwise_affine=True, eps=1e-6)
|
| 113 |
+
self.self_attn = SparseMultiHeadAttention(
|
| 114 |
+
channels,
|
| 115 |
+
num_heads=num_heads,
|
| 116 |
+
type="self",
|
| 117 |
+
attn_mode=attn_mode,
|
| 118 |
+
window_size=window_size,
|
| 119 |
+
shift_sequence=shift_sequence,
|
| 120 |
+
shift_window=shift_window,
|
| 121 |
+
serialize_mode=serialize_mode,
|
| 122 |
+
qkv_bias=qkv_bias,
|
| 123 |
+
use_rope=use_rope,
|
| 124 |
+
qk_rms_norm=qk_rms_norm,
|
| 125 |
+
)
|
| 126 |
+
self.cross_attn = SparseMultiHeadAttention(
|
| 127 |
+
channels,
|
| 128 |
+
ctx_channels=ctx_channels,
|
| 129 |
+
num_heads=num_heads,
|
| 130 |
+
type="cross",
|
| 131 |
+
attn_mode="full",
|
| 132 |
+
qkv_bias=qkv_bias,
|
| 133 |
+
qk_rms_norm=qk_rms_norm_cross,
|
| 134 |
+
)
|
| 135 |
+
self.mlp = SparseFeedForwardNet(
|
| 136 |
+
channels,
|
| 137 |
+
mlp_ratio=mlp_ratio,
|
| 138 |
+
)
|
| 139 |
+
if not share_mod:
|
| 140 |
+
self.adaLN_modulation = nn.Sequential(
|
| 141 |
+
nn.SiLU(),
|
| 142 |
+
nn.Linear(channels, 6 * channels, bias=True)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
|
| 146 |
+
if self.share_mod:
|
| 147 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
| 148 |
+
else:
|
| 149 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
| 150 |
+
h = x.replace(self.norm1(x.feats))
|
| 151 |
+
h = h * (1 + scale_msa) + shift_msa
|
| 152 |
+
h = self.self_attn(h)
|
| 153 |
+
h = h * gate_msa
|
| 154 |
+
x = x + h
|
| 155 |
+
h = x.replace(self.norm2(x.feats))
|
| 156 |
+
if self.norm_for_context:
|
| 157 |
+
if isinstance(context, SparseTensor):
|
| 158 |
+
context = context.replace(self.context_norm(context.feats))
|
| 159 |
+
else:
|
| 160 |
+
context = self.context_norm(context)
|
| 161 |
+
h = self.cross_attn(h, context)
|
| 162 |
+
x = x + h
|
| 163 |
+
h = x.replace(self.norm3(x.feats))
|
| 164 |
+
h = h * (1 + scale_mlp) + shift_mlp
|
| 165 |
+
h = self.mlp(h)
|
| 166 |
+
h = h * gate_mlp
|
| 167 |
+
x = x + h
|
| 168 |
+
return x
|
| 169 |
+
|
| 170 |
+
def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
|
| 171 |
+
if self.use_checkpoint:
|
| 172 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
|
| 173 |
+
else:
|
| 174 |
+
return self._forward(x, mod, context)
|
anigen/modules/spatial.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
|
| 5 |
+
"""
|
| 6 |
+
3D pixel shuffle.
|
| 7 |
+
"""
|
| 8 |
+
B, C, H, W, D = x.shape
|
| 9 |
+
C_ = C // scale_factor**3
|
| 10 |
+
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
|
| 11 |
+
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
| 12 |
+
x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor)
|
| 13 |
+
return x
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def patchify(x: torch.Tensor, patch_size: int):
|
| 17 |
+
"""
|
| 18 |
+
Patchify a tensor.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
x (torch.Tensor): (N, C, *spatial) tensor
|
| 22 |
+
patch_size (int): Patch size
|
| 23 |
+
"""
|
| 24 |
+
DIM = x.dim() - 2
|
| 25 |
+
for d in range(2, DIM + 2):
|
| 26 |
+
assert x.shape[d] % patch_size == 0, f"Dimension {d} of input tensor must be divisible by patch size, got {x.shape[d]} and {patch_size}"
|
| 27 |
+
|
| 28 |
+
x = x.reshape(*x.shape[:2], *sum([[x.shape[d] // patch_size, patch_size] for d in range(2, DIM + 2)], []))
|
| 29 |
+
x = x.permute(0, 1, *([2 * i + 3 for i in range(DIM)] + [2 * i + 2 for i in range(DIM)]))
|
| 30 |
+
x = x.reshape(x.shape[0], x.shape[1] * (patch_size ** DIM), *(x.shape[-DIM:]))
|
| 31 |
+
return x
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def unpatchify(x: torch.Tensor, patch_size: int):
|
| 35 |
+
"""
|
| 36 |
+
Unpatchify a tensor.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
x (torch.Tensor): (N, C, *spatial) tensor
|
| 40 |
+
patch_size (int): Patch size
|
| 41 |
+
"""
|
| 42 |
+
DIM = x.dim() - 2
|
| 43 |
+
assert x.shape[1] % (patch_size ** DIM) == 0, f"Second dimension of input tensor must be divisible by patch size to unpatchify, got {x.shape[1]} and {patch_size ** DIM}"
|
| 44 |
+
|
| 45 |
+
x = x.reshape(x.shape[0], x.shape[1] // (patch_size ** DIM), *([patch_size] * DIM), *(x.shape[-DIM:]))
|
| 46 |
+
x = x.permute(0, 1, *(sum([[2 + DIM + i, 2 + i] for i in range(DIM)], [])))
|
| 47 |
+
x = x.reshape(x.shape[0], x.shape[1], *[x.shape[2 + 2 * i] * patch_size for i in range(DIM)])
|
| 48 |
+
return x
|
anigen/modules/transformer/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .blocks import *
|
| 2 |
+
from .modulated import *
|
anigen/modules/transformer/blocks.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from ..attention import MultiHeadAttention
|
| 5 |
+
from ..norm import LayerNorm32
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AbsolutePositionEmbedder(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
Embeds spatial positions into vector representations.
|
| 11 |
+
"""
|
| 12 |
+
def __init__(self, channels: int, in_channels: int = 3):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.channels = channels
|
| 15 |
+
self.in_channels = in_channels
|
| 16 |
+
self.freq_dim = channels // in_channels // 2
|
| 17 |
+
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
| 18 |
+
self.freqs = 1.0 / (10000 ** self.freqs)
|
| 19 |
+
|
| 20 |
+
def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
|
| 21 |
+
"""
|
| 22 |
+
Create sinusoidal position embeddings.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
x: a 1-D Tensor of N indices
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
an (N, D) Tensor of positional embeddings.
|
| 29 |
+
"""
|
| 30 |
+
self.freqs = self.freqs.to(x.device)
|
| 31 |
+
out = torch.outer(x, self.freqs)
|
| 32 |
+
out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
|
| 33 |
+
return out
|
| 34 |
+
|
| 35 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
"""
|
| 37 |
+
Args:
|
| 38 |
+
x (torch.Tensor): (N, D) tensor of spatial positions
|
| 39 |
+
"""
|
| 40 |
+
N, D = x.shape
|
| 41 |
+
assert D == self.in_channels, "Input dimension must match number of input channels"
|
| 42 |
+
embed = self._sin_cos_embedding(x.reshape(-1))
|
| 43 |
+
embed = embed.reshape(N, -1)
|
| 44 |
+
if embed.shape[1] < self.channels:
|
| 45 |
+
embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
|
| 46 |
+
return embed
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class FeedForwardNet(nn.Module):
|
| 50 |
+
def __init__(self, channels: int, mlp_ratio: float = 4.0, out_channels: Optional[int] = None):
|
| 51 |
+
super().__init__()
|
| 52 |
+
if out_channels is None:
|
| 53 |
+
out_channels = channels
|
| 54 |
+
self.mlp = nn.Sequential(
|
| 55 |
+
nn.Linear(channels, int(channels * mlp_ratio)),
|
| 56 |
+
nn.GELU(approximate="tanh"),
|
| 57 |
+
nn.Linear(int(channels * mlp_ratio), out_channels),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 61 |
+
return self.mlp(x)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class TransformerBlock(nn.Module):
|
| 65 |
+
"""
|
| 66 |
+
Transformer block (MSA + FFN).
|
| 67 |
+
"""
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
channels: int,
|
| 71 |
+
num_heads: int,
|
| 72 |
+
out_channels: Optional[int] = None,
|
| 73 |
+
mlp_ratio: float = 4.0,
|
| 74 |
+
attn_mode: Literal["full", "windowed"] = "full",
|
| 75 |
+
window_size: Optional[int] = None,
|
| 76 |
+
shift_window: Optional[int] = None,
|
| 77 |
+
use_checkpoint: bool = False,
|
| 78 |
+
use_rope: bool = False,
|
| 79 |
+
qk_rms_norm: bool = False,
|
| 80 |
+
qkv_bias: bool = True,
|
| 81 |
+
ln_affine: bool = False,
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.use_checkpoint = use_checkpoint
|
| 85 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 86 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 87 |
+
self.attn = MultiHeadAttention(
|
| 88 |
+
channels,
|
| 89 |
+
num_heads=num_heads,
|
| 90 |
+
attn_mode=attn_mode,
|
| 91 |
+
window_size=window_size,
|
| 92 |
+
shift_window=shift_window,
|
| 93 |
+
qkv_bias=qkv_bias,
|
| 94 |
+
use_rope=use_rope,
|
| 95 |
+
qk_rms_norm=qk_rms_norm,
|
| 96 |
+
)
|
| 97 |
+
self.channels = channels
|
| 98 |
+
self.out_channels = out_channels if out_channels is not None else channels
|
| 99 |
+
self.mlp = FeedForwardNet(
|
| 100 |
+
self.channels,
|
| 101 |
+
out_channels=self.out_channels,
|
| 102 |
+
mlp_ratio=mlp_ratio,
|
| 103 |
+
)
|
| 104 |
+
if self.out_channels != self.channels:
|
| 105 |
+
self.res_mlp = FeedForwardNet(
|
| 106 |
+
self.channels,
|
| 107 |
+
out_channels=self.out_channels,
|
| 108 |
+
mlp_ratio=1.0,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def _forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 112 |
+
h = self.norm1(x)
|
| 113 |
+
h = self.attn(h)
|
| 114 |
+
x = x + h
|
| 115 |
+
h = self.norm2(x)
|
| 116 |
+
h = self.mlp(h)
|
| 117 |
+
if self.out_channels != self.channels:
|
| 118 |
+
x = self.res_mlp(x)
|
| 119 |
+
x = x + h
|
| 120 |
+
return x
|
| 121 |
+
|
| 122 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
if self.use_checkpoint:
|
| 124 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
|
| 125 |
+
else:
|
| 126 |
+
return self._forward(x)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class SkinTransformerCrossBlock(nn.Module):
|
| 130 |
+
"""
|
| 131 |
+
Transformer block (MSA + FFN).
|
| 132 |
+
"""
|
| 133 |
+
def __init__(
|
| 134 |
+
self,
|
| 135 |
+
channels: int,
|
| 136 |
+
num_heads: int,
|
| 137 |
+
out_channels: Optional[int] = None,
|
| 138 |
+
mlp_ratio: float = 4.0,
|
| 139 |
+
use_checkpoint: bool = False,
|
| 140 |
+
qkv_bias: bool = True,
|
| 141 |
+
ln_affine: bool = False,
|
| 142 |
+
):
|
| 143 |
+
super().__init__()
|
| 144 |
+
self.use_checkpoint = use_checkpoint
|
| 145 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 146 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 147 |
+
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 148 |
+
self.to_v = nn.Linear(channels, channels, bias=qkv_bias)
|
| 149 |
+
self.channels = channels
|
| 150 |
+
self.out_channels = out_channels if out_channels is not None else channels
|
| 151 |
+
self.mlp = FeedForwardNet(
|
| 152 |
+
self.channels,
|
| 153 |
+
out_channels=self.out_channels,
|
| 154 |
+
mlp_ratio=mlp_ratio,
|
| 155 |
+
)
|
| 156 |
+
self.joint_attn = MultiHeadAttention(
|
| 157 |
+
channels,
|
| 158 |
+
num_heads=num_heads,
|
| 159 |
+
type="self",
|
| 160 |
+
attn_mode="full",
|
| 161 |
+
qkv_bias=qkv_bias,
|
| 162 |
+
)
|
| 163 |
+
self.joint_mlp = FeedForwardNet(
|
| 164 |
+
self.channels,
|
| 165 |
+
out_channels=self.out_channels,
|
| 166 |
+
mlp_ratio=mlp_ratio,
|
| 167 |
+
)
|
| 168 |
+
if self.out_channels != self.channels:
|
| 169 |
+
self.res_mlp = FeedForwardNet(
|
| 170 |
+
self.channels,
|
| 171 |
+
out_channels=self.out_channels,
|
| 172 |
+
mlp_ratio=1.0,
|
| 173 |
+
)
|
| 174 |
+
self.res_joint_mlp = FeedForwardNet(
|
| 175 |
+
self.channels,
|
| 176 |
+
out_channels=self.out_channels,
|
| 177 |
+
mlp_ratio=1.0,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def _forward(self, x: torch.Tensor, j: torch.Tensor, skin: torch.Tensor) -> torch.Tensor:
|
| 181 |
+
v = self.to_v(self.norm1(j))
|
| 182 |
+
h = skin @ v
|
| 183 |
+
x = x + h
|
| 184 |
+
h = self.norm2(x)
|
| 185 |
+
h = self.mlp(h)
|
| 186 |
+
if self.out_channels != self.channels:
|
| 187 |
+
x = self.res_mlp(x)
|
| 188 |
+
x = x + h
|
| 189 |
+
|
| 190 |
+
h_j = self.norm3(j)
|
| 191 |
+
h_j = self.joint_attn(h_j)
|
| 192 |
+
h_j = j + h_j
|
| 193 |
+
h_j = self.joint_mlp(h_j)
|
| 194 |
+
if self.out_channels != self.channels:
|
| 195 |
+
j = self.res_joint_mlp(j)
|
| 196 |
+
j = j + h_j
|
| 197 |
+
return x, j
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class TransformerCrossBlock(nn.Module):
|
| 201 |
+
"""
|
| 202 |
+
Transformer cross-attention block (MSA + MCA + FFN).
|
| 203 |
+
"""
|
| 204 |
+
def __init__(
|
| 205 |
+
self,
|
| 206 |
+
channels: int,
|
| 207 |
+
ctx_channels: int,
|
| 208 |
+
num_heads: int,
|
| 209 |
+
out_channels: Optional[int] = None,
|
| 210 |
+
mlp_ratio: float = 4.0,
|
| 211 |
+
attn_mode: Literal["full", "windowed"] = "full",
|
| 212 |
+
window_size: Optional[int] = None,
|
| 213 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 214 |
+
use_checkpoint: bool = False,
|
| 215 |
+
use_rope: bool = False,
|
| 216 |
+
qk_rms_norm: bool = False,
|
| 217 |
+
qk_rms_norm_cross: bool = False,
|
| 218 |
+
qkv_bias: bool = True,
|
| 219 |
+
ln_affine: bool = False,
|
| 220 |
+
x_is_query: bool = False,
|
| 221 |
+
no_self: bool = False,
|
| 222 |
+
):
|
| 223 |
+
super().__init__()
|
| 224 |
+
self.use_checkpoint = use_checkpoint
|
| 225 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6) if not no_self else nn.Identity()
|
| 226 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 227 |
+
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
|
| 228 |
+
if no_self:
|
| 229 |
+
self.self_attn = lambda x: 0
|
| 230 |
+
else:
|
| 231 |
+
self.self_attn = MultiHeadAttention(
|
| 232 |
+
channels,
|
| 233 |
+
num_heads=num_heads,
|
| 234 |
+
type="self",
|
| 235 |
+
attn_mode=attn_mode,
|
| 236 |
+
window_size=window_size,
|
| 237 |
+
shift_window=shift_window,
|
| 238 |
+
qkv_bias=qkv_bias,
|
| 239 |
+
use_rope=use_rope,
|
| 240 |
+
qk_rms_norm=qk_rms_norm,
|
| 241 |
+
)
|
| 242 |
+
self.cross_attn = MultiHeadAttention(
|
| 243 |
+
channels,
|
| 244 |
+
ctx_channels=ctx_channels,
|
| 245 |
+
num_heads=num_heads,
|
| 246 |
+
type="cross",
|
| 247 |
+
attn_mode="full",
|
| 248 |
+
qkv_bias=qkv_bias,
|
| 249 |
+
qk_rms_norm=qk_rms_norm_cross,
|
| 250 |
+
x_is_query=x_is_query,
|
| 251 |
+
)
|
| 252 |
+
self.channels = channels
|
| 253 |
+
self.out_channels = out_channels if out_channels is not None else channels
|
| 254 |
+
self.mlp = FeedForwardNet(
|
| 255 |
+
channels,
|
| 256 |
+
out_channels=self.out_channels,
|
| 257 |
+
mlp_ratio=mlp_ratio,
|
| 258 |
+
)
|
| 259 |
+
if self.out_channels != self.channels:
|
| 260 |
+
self.res_mlp = FeedForwardNet(
|
| 261 |
+
self.channels,
|
| 262 |
+
out_channels=self.out_channels,
|
| 263 |
+
mlp_ratio=1.0,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
def _forward(self, x: torch.Tensor, context: torch.Tensor):
|
| 267 |
+
h = self.norm1(x)
|
| 268 |
+
h = self.self_attn(h)
|
| 269 |
+
x = x + h
|
| 270 |
+
h = self.norm2(x)
|
| 271 |
+
h = self.cross_attn(h, context)
|
| 272 |
+
x = x + h
|
| 273 |
+
h = self.norm3(x)
|
| 274 |
+
h = self.mlp(h)
|
| 275 |
+
if self.out_channels != self.channels:
|
| 276 |
+
x = self.res_mlp(x)
|
| 277 |
+
x = x + h
|
| 278 |
+
return x
|
| 279 |
+
|
| 280 |
+
def forward(self, x: torch.Tensor, context: torch.Tensor):
|
| 281 |
+
if self.use_checkpoint:
|
| 282 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
|
| 283 |
+
else:
|
| 284 |
+
return self._forward(x, context)
|
| 285 |
+
|
anigen/modules/transformer/modulated.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from ..attention import MultiHeadAttention
|
| 5 |
+
from ..norm import LayerNorm32
|
| 6 |
+
from .blocks import FeedForwardNet
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ModulatedTransformerBlock(nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
Transformer block (MSA + FFN) with adaptive layer norm conditioning.
|
| 12 |
+
"""
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
channels: int,
|
| 16 |
+
num_heads: int,
|
| 17 |
+
mlp_ratio: float = 4.0,
|
| 18 |
+
attn_mode: Literal["full", "windowed"] = "full",
|
| 19 |
+
window_size: Optional[int] = None,
|
| 20 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 21 |
+
use_checkpoint: bool = False,
|
| 22 |
+
use_rope: bool = False,
|
| 23 |
+
qk_rms_norm: bool = False,
|
| 24 |
+
qkv_bias: bool = True,
|
| 25 |
+
share_mod: bool = False,
|
| 26 |
+
):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.use_checkpoint = use_checkpoint
|
| 29 |
+
self.share_mod = share_mod
|
| 30 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 31 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 32 |
+
self.attn = MultiHeadAttention(
|
| 33 |
+
channels,
|
| 34 |
+
num_heads=num_heads,
|
| 35 |
+
attn_mode=attn_mode,
|
| 36 |
+
window_size=window_size,
|
| 37 |
+
shift_window=shift_window,
|
| 38 |
+
qkv_bias=qkv_bias,
|
| 39 |
+
use_rope=use_rope,
|
| 40 |
+
qk_rms_norm=qk_rms_norm,
|
| 41 |
+
)
|
| 42 |
+
self.mlp = FeedForwardNet(
|
| 43 |
+
channels,
|
| 44 |
+
mlp_ratio=mlp_ratio,
|
| 45 |
+
)
|
| 46 |
+
if not share_mod:
|
| 47 |
+
self.adaLN_modulation = nn.Sequential(
|
| 48 |
+
nn.SiLU(),
|
| 49 |
+
nn.Linear(channels, 6 * channels, bias=True)
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
if self.share_mod:
|
| 54 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
| 55 |
+
else:
|
| 56 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
| 57 |
+
h = self.norm1(x)
|
| 58 |
+
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
| 59 |
+
h = self.attn(h)
|
| 60 |
+
h = h * gate_msa.unsqueeze(1)
|
| 61 |
+
x = x + h
|
| 62 |
+
h = self.norm2(x)
|
| 63 |
+
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
| 64 |
+
h = self.mlp(h)
|
| 65 |
+
h = h * gate_mlp.unsqueeze(1)
|
| 66 |
+
x = x + h
|
| 67 |
+
return x
|
| 68 |
+
|
| 69 |
+
def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
|
| 70 |
+
if self.use_checkpoint:
|
| 71 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
|
| 72 |
+
else:
|
| 73 |
+
return self._forward(x, mod)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class ModulatedTransformerCrossBlock(nn.Module):
|
| 77 |
+
"""
|
| 78 |
+
Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
| 79 |
+
"""
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
channels: int,
|
| 83 |
+
ctx_channels: int,
|
| 84 |
+
num_heads: int,
|
| 85 |
+
mlp_ratio: float = 4.0,
|
| 86 |
+
attn_mode: Literal["full", "windowed"] = "full",
|
| 87 |
+
window_size: Optional[int] = None,
|
| 88 |
+
shift_window: Optional[Tuple[int, int, int]] = None,
|
| 89 |
+
use_checkpoint: bool = False,
|
| 90 |
+
use_rope: bool = False,
|
| 91 |
+
qk_rms_norm: bool = False,
|
| 92 |
+
qk_rms_norm_cross: bool = False,
|
| 93 |
+
qkv_bias: bool = True,
|
| 94 |
+
share_mod: bool = False,
|
| 95 |
+
|
| 96 |
+
use_lora_self: bool = False,
|
| 97 |
+
lora_rank_self: int = 4,
|
| 98 |
+
use_lora_cross: bool = False,
|
| 99 |
+
lora_rank_cross: int = 4,
|
| 100 |
+
lora_lr_rate: float = 1.0,
|
| 101 |
+
use_context_norm: bool = False,
|
| 102 |
+
):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.use_checkpoint = use_checkpoint
|
| 105 |
+
self.share_mod = share_mod
|
| 106 |
+
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 107 |
+
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
| 108 |
+
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
| 109 |
+
self.use_context_norm = use_context_norm
|
| 110 |
+
if self.use_context_norm:
|
| 111 |
+
self.context_norm = LayerNorm32(ctx_channels, elementwise_affine=True, eps=1e-6)
|
| 112 |
+
self.self_attn = MultiHeadAttention(
|
| 113 |
+
channels,
|
| 114 |
+
num_heads=num_heads,
|
| 115 |
+
type="self",
|
| 116 |
+
attn_mode=attn_mode,
|
| 117 |
+
window_size=window_size,
|
| 118 |
+
shift_window=shift_window,
|
| 119 |
+
qkv_bias=qkv_bias,
|
| 120 |
+
use_rope=use_rope,
|
| 121 |
+
qk_rms_norm=qk_rms_norm,
|
| 122 |
+
use_lora=use_lora_self,
|
| 123 |
+
lora_rank=lora_rank_self,
|
| 124 |
+
lora_lr_rate=lora_lr_rate,
|
| 125 |
+
)
|
| 126 |
+
self.cross_attn = MultiHeadAttention(
|
| 127 |
+
channels,
|
| 128 |
+
ctx_channels=ctx_channels,
|
| 129 |
+
num_heads=num_heads,
|
| 130 |
+
type="cross",
|
| 131 |
+
attn_mode="full",
|
| 132 |
+
qkv_bias=qkv_bias,
|
| 133 |
+
qk_rms_norm=qk_rms_norm_cross,
|
| 134 |
+
use_lora=use_lora_cross,
|
| 135 |
+
lora_rank=lora_rank_cross,
|
| 136 |
+
lora_lr_rate=lora_lr_rate,
|
| 137 |
+
)
|
| 138 |
+
self.mlp = FeedForwardNet(
|
| 139 |
+
channels,
|
| 140 |
+
mlp_ratio=mlp_ratio,
|
| 141 |
+
)
|
| 142 |
+
if not share_mod:
|
| 143 |
+
self.adaLN_modulation = nn.Sequential(
|
| 144 |
+
nn.SiLU(),
|
| 145 |
+
nn.Linear(channels, 6 * channels, bias=True)
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
|
| 149 |
+
if self.share_mod:
|
| 150 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
| 151 |
+
else:
|
| 152 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
| 153 |
+
h = self.norm1(x)
|
| 154 |
+
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
| 155 |
+
h = self.self_attn(h)
|
| 156 |
+
h = h * gate_msa.unsqueeze(1)
|
| 157 |
+
x = x + h
|
| 158 |
+
h = self.norm2(x)
|
| 159 |
+
if self.use_context_norm:
|
| 160 |
+
context = self.context_norm(context)
|
| 161 |
+
h = self.cross_attn(h, context)
|
| 162 |
+
x = x + h
|
| 163 |
+
h = self.norm3(x)
|
| 164 |
+
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
| 165 |
+
h = self.mlp(h)
|
| 166 |
+
h = h * gate_mlp.unsqueeze(1)
|
| 167 |
+
x = x + h
|
| 168 |
+
return x
|
| 169 |
+
|
| 170 |
+
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
|
| 171 |
+
if self.use_checkpoint:
|
| 172 |
+
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
|
| 173 |
+
else:
|
| 174 |
+
return self._forward(x, mod, context)
|
| 175 |
+
|
anigen/modules/utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from ..modules import sparse as sp
|
| 3 |
+
|
| 4 |
+
FP16_MODULES = (
|
| 5 |
+
nn.Conv1d,
|
| 6 |
+
nn.Conv2d,
|
| 7 |
+
nn.Conv3d,
|
| 8 |
+
nn.ConvTranspose1d,
|
| 9 |
+
nn.ConvTranspose2d,
|
| 10 |
+
nn.ConvTranspose3d,
|
| 11 |
+
nn.Linear,
|
| 12 |
+
sp.SparseConv3d,
|
| 13 |
+
sp.SparseInverseConv3d,
|
| 14 |
+
sp.SparseLinear,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def convert_module_to_f16(l):
|
| 18 |
+
"""
|
| 19 |
+
Convert primitive modules to float16.
|
| 20 |
+
"""
|
| 21 |
+
if isinstance(l, FP16_MODULES):
|
| 22 |
+
for p in l.parameters():
|
| 23 |
+
p.data = p.data.half()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def convert_module_to_f32(l):
|
| 27 |
+
"""
|
| 28 |
+
Convert primitive modules to float32, undoing convert_module_to_f16().
|
| 29 |
+
"""
|
| 30 |
+
if isinstance(l, FP16_MODULES):
|
| 31 |
+
for p in l.parameters():
|
| 32 |
+
p.data = p.data.float()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def zero_module(module):
|
| 36 |
+
"""
|
| 37 |
+
Zero out the parameters of a module and return it.
|
| 38 |
+
"""
|
| 39 |
+
for p in module.parameters():
|
| 40 |
+
p.detach().zero_()
|
| 41 |
+
return module
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def scale_module(module, scale):
|
| 45 |
+
"""
|
| 46 |
+
Scale the parameters of a module and return it.
|
| 47 |
+
"""
|
| 48 |
+
for p in module.parameters():
|
| 49 |
+
p.detach().mul_(scale)
|
| 50 |
+
return module
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def modulate(x, shift, scale):
|
| 54 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|