Commit
·
e86746e
0
Parent(s):
Initial commit: FloodDiffusionTiny - Tiny text-to-motion model with UMT5-Base
Browse files- .gitattributes +36 -0
- .gitignore +54 -0
- README.md +177 -0
- __init__.py +11 -0
- config.json +11 -0
- generate_ldf.py +139 -0
- hf_pipeline.py +282 -0
- ldf.yaml +44 -0
- ldf_models/__init__.py +0 -0
- ldf_models/diffusion_forcing_wan_tiny.py +943 -0
- ldf_models/tools/attention.py +188 -0
- ldf_models/tools/t5.py +564 -0
- ldf_models/tools/tokenizers.py +84 -0
- ldf_models/tools/wan_model.py +592 -0
- ldf_models/tools/wan_vae_1d.py +762 -0
- ldf_models/vae_wan_1d.py +212 -0
- ldf_utils/__init__.py +0 -0
- ldf_utils/initialize.py +286 -0
- ldf_utils/math/__init__.py +0 -0
- ldf_utils/math/quaternion.py +447 -0
- ldf_utils/motion_process.py +365 -0
- model.safetensors +3 -0
- requirements.txt +19 -0
- vae.safetensors +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python cache
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
|
| 8 |
+
# Virtual environments
|
| 9 |
+
venv/
|
| 10 |
+
env/
|
| 11 |
+
ENV/
|
| 12 |
+
.venv
|
| 13 |
+
|
| 14 |
+
# PyTorch/Model cache
|
| 15 |
+
*.pth~
|
| 16 |
+
*.safetensors~
|
| 17 |
+
checkpoint/
|
| 18 |
+
checkpoints/
|
| 19 |
+
|
| 20 |
+
# Hugging Face cache
|
| 21 |
+
.cache/
|
| 22 |
+
huggingface_cache/
|
| 23 |
+
|
| 24 |
+
# Generated outputs
|
| 25 |
+
outputs/
|
| 26 |
+
generated_motions/
|
| 27 |
+
*.npy
|
| 28 |
+
*.pkl
|
| 29 |
+
|
| 30 |
+
# IDE
|
| 31 |
+
.vscode/
|
| 32 |
+
.idea/
|
| 33 |
+
*.swp
|
| 34 |
+
*.swo
|
| 35 |
+
*~
|
| 36 |
+
|
| 37 |
+
# OS
|
| 38 |
+
.DS_Store
|
| 39 |
+
Thumbs.db
|
| 40 |
+
|
| 41 |
+
# Jupyter
|
| 42 |
+
.ipynb_checkpoints/
|
| 43 |
+
*.ipynb
|
| 44 |
+
|
| 45 |
+
# Logs
|
| 46 |
+
*.log
|
| 47 |
+
logs/
|
| 48 |
+
wandb/
|
| 49 |
+
|
| 50 |
+
# Test outputs
|
| 51 |
+
test_output/
|
| 52 |
+
test_results/
|
| 53 |
+
tmp/
|
| 54 |
+
|
README.md
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- text-to-motion
|
| 5 |
+
- motion-generation
|
| 6 |
+
- diffusion-forcing
|
| 7 |
+
- humanml3d
|
| 8 |
+
- computer-animation
|
| 9 |
+
library_name: transformers
|
| 10 |
+
pipeline_tag: other
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# FloodDiffusion: Tailored Diffusion Forcing for Streaming Motion Generation
|
| 14 |
+
|
| 15 |
+
<div align="center">
|
| 16 |
+
|
| 17 |
+
**A TINY version of the original FloodDiffusion**
|
| 18 |
+
|
| 19 |
+
[Paper](https://arxiv.org/abs/2512.03520) | [Github](https://github.com/ShandaAI/FloodDiffusion) | [Project Page](https://shandaai.github.io/FloodDiffusion/)
|
| 20 |
+
|
| 21 |
+
</div>
|
| 22 |
+
|
| 23 |
+
## Installation
|
| 24 |
+
|
| 25 |
+
### Prerequisites
|
| 26 |
+
|
| 27 |
+
- Python 3.8+
|
| 28 |
+
- CUDA-capable GPU with 16GB+ VRAM (recommended)
|
| 29 |
+
- 16GB+ system RAM
|
| 30 |
+
|
| 31 |
+
### Dependencies
|
| 32 |
+
|
| 33 |
+
**Step 1: Install basic dependencies**
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
pip install torch transformers huggingface_hub
|
| 37 |
+
pip install lightning diffusers omegaconf ftfy numpy
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
**Step 2: Install Flash Attention (Required)**
|
| 41 |
+
|
| 42 |
+
Flash attention requires CUDA and may need compilation. Choose the appropriate method:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
pip install flash-attn --no-build-isolation
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
**Note:** Flash attention is **required** for this model. If installation fails, please refer to the [official flash-attention installation guide](https://github.com/Dao-AILab/flash-attention#installation-and-features).
|
| 49 |
+
|
| 50 |
+
## Quick Start
|
| 51 |
+
|
| 52 |
+
### Basic Usage
|
| 53 |
+
|
| 54 |
+
```python
|
| 55 |
+
from transformers import AutoModel
|
| 56 |
+
|
| 57 |
+
# Load model
|
| 58 |
+
model = AutoModel.from_pretrained(
|
| 59 |
+
"ShandaAI/FloodDiffusionTiny",
|
| 60 |
+
trust_remote_code=True
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Generate motion from text (263-dim HumanML3D features)
|
| 64 |
+
motion = model("a person walking forward", length=60)
|
| 65 |
+
print(f"Generated motion: {motion.shape}") # (~240, 263)
|
| 66 |
+
|
| 67 |
+
# Generate motion as joint coordinates (22 joints × 3 coords) with ema (alpha: 0.0-1.0)
|
| 68 |
+
motion_joints = model("a person walking forward", length=60, output_joints=True, smoothing_alpha=0.5)
|
| 69 |
+
print(f"Generated joints: {motion_joints.shape}") # (~240, 22, 3)
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### Batch Generation
|
| 73 |
+
|
| 74 |
+
```python
|
| 75 |
+
# Generate multiple motions efficiently
|
| 76 |
+
texts = [
|
| 77 |
+
"a person walking forward",
|
| 78 |
+
"a person running quickly",
|
| 79 |
+
"a person jumping up and down"
|
| 80 |
+
]
|
| 81 |
+
lengths = [60, 50, 40] # Different lengths for each motion
|
| 82 |
+
|
| 83 |
+
motions = model(texts, length=lengths)
|
| 84 |
+
|
| 85 |
+
for i, motion in enumerate(motions):
|
| 86 |
+
print(f"Motion {i}: {motion.shape}")
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
### Multi-Text Motion Transitions
|
| 90 |
+
|
| 91 |
+
```python
|
| 92 |
+
# Generate a motion sequence with smooth transitions between actions
|
| 93 |
+
motion = model(
|
| 94 |
+
text=[["walk forward", "turn around", "run back"]],
|
| 95 |
+
length=[120],
|
| 96 |
+
text_end=[[40, 80, 120]] # Transition points in latent tokens
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Output: ~480 frames showing all three actions smoothly connected
|
| 100 |
+
print(f"Transition motion: {motion[0].shape}")
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
## API Reference
|
| 104 |
+
|
| 105 |
+
### `model(text, length=60, text_end=None, num_denoise_steps=None, output_joints=False, smoothing_alpha=1.0)`
|
| 106 |
+
|
| 107 |
+
Generate motion sequences from text descriptions.
|
| 108 |
+
|
| 109 |
+
**Parameters:**
|
| 110 |
+
|
| 111 |
+
- **text** (`str`, `List[str]`, or `List[List[str]]`): Text description(s)
|
| 112 |
+
- Single string: Generate one motion
|
| 113 |
+
- List of strings: Batch generation
|
| 114 |
+
- Nested list: Multiple text prompts per motion (for transitions)
|
| 115 |
+
|
| 116 |
+
- **length** (`int` or `List[int]`, default=60): Number of latent tokens to generate
|
| 117 |
+
- Output frames ≈ `length × 4` (due to VAE upsampling)
|
| 118 |
+
- Example: `length=60` → ~240 frames (~12 seconds at 20 FPS)
|
| 119 |
+
|
| 120 |
+
- **text_end** (`List[int]` or `List[List[int]]`, optional): Latent token positions for text transitions
|
| 121 |
+
- Only used when `text` is a nested list
|
| 122 |
+
- Specifies when to switch between different text descriptions
|
| 123 |
+
- **IMPORTANT**: Must have the same length as the corresponding text list
|
| 124 |
+
- Example: `text=[["walk", "turn", "sit"]]` requires `text_end=[[20, 40, 60]]` (3 endpoints for 3 texts)
|
| 125 |
+
- Must be in ascending order
|
| 126 |
+
|
| 127 |
+
- **num_denoise_steps** (`int`, optional): Number of denoising iterations
|
| 128 |
+
- Higher values produce better quality but slower generation
|
| 129 |
+
- Recommended range: 10-50
|
| 130 |
+
|
| 131 |
+
- **output_joints** (`bool`, default=False): Output format selector
|
| 132 |
+
- `False`: Returns 263-dimensional HumanML3D features
|
| 133 |
+
- `True`: Returns 22×3 joint coordinates for direct visualization
|
| 134 |
+
|
| 135 |
+
- **smoothing_alpha** (`float`, default=1.0): EMA smoothing factor for joint positions (only used when `output_joints=True`)
|
| 136 |
+
- `1.0`: No smoothing (default)
|
| 137 |
+
- `0.5`: Medium smoothing (recommended for smoother animations)
|
| 138 |
+
- `0.0`: Maximum smoothing
|
| 139 |
+
- Range: 0.0 to 1.0
|
| 140 |
+
|
| 141 |
+
**Returns:**
|
| 142 |
+
- Single motion:
|
| 143 |
+
- `output_joints=False`: `numpy.ndarray` of shape `(frames, 263)`
|
| 144 |
+
- `output_joints=True`: `numpy.ndarray` of shape `(frames, 22, 3)`
|
| 145 |
+
- Batch: `List[numpy.ndarray]` with shapes as above
|
| 146 |
+
|
| 147 |
+
**Example:**
|
| 148 |
+
```python
|
| 149 |
+
# Single generation (263-dim features)
|
| 150 |
+
motion = model("walk forward", length=60) # Returns (240, 263)
|
| 151 |
+
|
| 152 |
+
# Single generation (joint coordinates)
|
| 153 |
+
joints = model("walk forward", length=60, output_joints=True) # Returns (240, 22, 3)
|
| 154 |
+
|
| 155 |
+
# Batch generation
|
| 156 |
+
motions = model(["walk", "run"], length=[60, 50]) # Returns list of 2 arrays
|
| 157 |
+
|
| 158 |
+
# Multi-text transitions
|
| 159 |
+
motion = model(
|
| 160 |
+
[["walk", "turn"]],
|
| 161 |
+
length=[60],
|
| 162 |
+
text_end=[[30, 60]]
|
| 163 |
+
) # Returns list with 1 array of shape (240, 263)
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
## Citation
|
| 167 |
+
|
| 168 |
+
If you use this model in your research, please cite:
|
| 169 |
+
|
| 170 |
+
```bibtex
|
| 171 |
+
@article{cai2025flooddiffusion,
|
| 172 |
+
title={FloodDiffusion: Tailored Diffusion Forcing for Streaming Motion Generation},
|
| 173 |
+
author={Yiyi Cai and Yuhan Wu and Kunhang Li and You Zhou and Bo Zheng and Haiyang Liu},
|
| 174 |
+
journal={arXiv preprint arXiv:2512.03520},
|
| 175 |
+
year={2025}
|
| 176 |
+
}
|
| 177 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FloodDiffusion - Text-to-Motion Generation
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
from transformers import AutoModel
|
| 6 |
+
|
| 7 |
+
model = AutoModel.from_pretrained("ShandaAI/FloodDiffusionTiny", trust_remote_code=True)
|
| 8 |
+
motion = model("a person walking forward", length=60)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
__version__ = "1.0.0"
|
config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": ["LDFModel"],
|
| 3 |
+
"model_type": "ldf_motion",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoModel": "hf_pipeline.LDFModel",
|
| 6 |
+
"AutoConfig": "hf_pipeline.LDFConfig"
|
| 7 |
+
},
|
| 8 |
+
"torch_dtype": "float32",
|
| 9 |
+
"transformers_version": "4.30.0",
|
| 10 |
+
"license": "apache-2.0"
|
| 11 |
+
}
|
generate_ldf.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from lightning import seed_everything
|
| 6 |
+
from safetensors.torch import load_file as load_safetensors
|
| 7 |
+
|
| 8 |
+
from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config
|
| 9 |
+
|
| 10 |
+
# Set tokenizers parallelism to false to avoid warnings in multiprocessing
|
| 11 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def resolve_config_path(path, base_dir):
    """Return *path* unchanged if it is absolute (or falsy), else join onto *base_dir*."""
    if path and not os.path.isabs(path):
        return os.path.join(base_dir, path)
    return path


def load_model_from_config():
    """Build and load the VAE and the diffusion model described by the CLI config.

    Reads the YAML config pointed to by ``--config`` in ``sys.argv``; relative
    checkpoint paths inside the config are resolved against the config file's
    directory (or the current working directory when ``--config`` is absent).
    Both networks are instantiated, their safetensors checkpoints are loaded
    strictly, and they are returned in ``eval()`` mode on the best available
    device.

    Returns:
        tuple: ``(vae, model)`` — loaded modules on CUDA when available,
        otherwise CPU, both switched to eval mode.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_float32_matmul_precision("high")
    cfg = load_config()
    seed_everything(cfg.seed)

    # Relative checkpoint paths are interpreted relative to the directory
    # containing the config file; fall back to the current directory.
    if '--config' in sys.argv:
        config_idx = sys.argv.index('--config') + 1
        config_dir = os.path.dirname(os.path.abspath(sys.argv[config_idx]))
    else:
        config_dir = os.getcwd()

    # --- VAE -----------------------------------------------------------
    vae = instantiate(
        target=cfg.test_vae.target,
        cfg=None,
        hfstyle=False,
        **cfg.test_vae.params,
    )

    vae_path = resolve_config_path(cfg.test_vae_ckpt, config_dir)

    # Load from safetensors (already contains EMA weights)
    vae_state_dict = load_safetensors(vae_path)
    vae.load_state_dict(vae_state_dict, strict=True)
    print(f"Loaded VAE model from {vae_path}")

    compare_statedict_and_parameters(
        state_dict=vae.state_dict(),
        named_parameters=vae.named_parameters(),
        named_buffers=vae.named_buffers(),
    )
    vae.to(device)
    vae.eval()

    # --- Diffusion model -----------------------------------------------
    # Fix relative paths inside the model params before instantiation.
    model_params = dict(cfg.model.params)
    for key in ('checkpoint_path', 'tokenizer_path'):
        if key in model_params and model_params[key]:
            model_params[key] = resolve_config_path(model_params[key], config_dir)

    model = instantiate(
        target=cfg.model.target, cfg=None, hfstyle=False, **model_params
    )

    model_path = resolve_config_path(cfg.test_ckpt, config_dir)

    # Load from safetensors (already contains EMA weights)
    model_state_dict = load_safetensors(model_path)
    model.load_state_dict(model_state_dict, strict=True)
    print(f"Loaded model from {model_path}")

    compare_statedict_and_parameters(
        state_dict=model.state_dict(),
        named_parameters=model.named_parameters(),
        named_buffers=model.named_buffers(),
    )
    model.to(device)
    model.eval()

    return vae, model
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@torch.inference_mode()
def generate_feature_stream(
    model, feature_length, text, feature_text_end=None, num_denoise_steps=None
):
    """Stream generated feature segments from the model as they are produced.

    Args:
        model: Loaded diffusion model exposing a ``stream_generate`` generator.
        feature_length: List[int], generation length for each sample.
        text: List[str] or List[List[str]] of text prompts.
        feature_text_end: Optional List[List[int]] with the time points where
            each text ends (only meaningful when ``text`` is a list of lists).
        num_denoise_steps: Optional number of denoising iterations per step.

    Yields:
        dict: step outputs from ``model.stream_generate`` — each contains a
        "generated" key with the current feature segment.
    """
    # stream_generate expects a batch dict with "feature_length", "text",
    # and (optionally) "feature_text_end".
    batch = {
        "feature_length": torch.tensor(feature_length),
        "text": text,
    }
    if feature_text_end is not None:
        batch["feature_text_end"] = feature_text_end

    # stream_generate is itself a generator; forward its outputs unchanged.
    yield from model.stream_generate(batch, num_denoise_steps=num_denoise_steps)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
|
| 121 |
+
import argparse
|
| 122 |
+
|
| 123 |
+
parser = argparse.ArgumentParser()
|
| 124 |
+
parser.add_argument("--config", type=str, required=True, help="Path to config")
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--text", type=str, default="a person walks forward", help="Text prompt"
|
| 127 |
+
)
|
| 128 |
+
parser.add_argument("--length", type=int, default=120, help="Motion length")
|
| 129 |
+
parser.add_argument(
|
| 130 |
+
"--output", type=str, default="output.mp4", help="Output video path"
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
|
| 134 |
+
)
|
| 135 |
+
args = parser.parse_args()
|
| 136 |
+
|
| 137 |
+
print("Loading model...")
|
| 138 |
+
vae, model = load_model_from_config()
|
| 139 |
+
|
hf_pipeline.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LDF Model for Hugging Face Hub
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
from transformers import AutoModel
|
| 6 |
+
|
| 7 |
+
model = AutoModel.from_pretrained("ShandaAI/FloodDiffusionTiny", trust_remote_code=True)
|
| 8 |
+
motion = model("a person walking forward", length=60)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
| 13 |
+
from typing import Union, List, Optional
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LDFConfig(PretrainedConfig):
    """Configuration for the LDF motion generation model.

    Attributes:
        input_dim: Dimensionality of the latent input (default 4).
        output_dim: Dimensionality of the decoded motion features
            (default 263, the HumanML3D feature size).
    """

    model_type = "ldf_motion"

    def __init__(self, input_dim=4, output_dim=263, **kwargs):
        # Forward any extra hub/config keys to the transformers base class.
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LDFModel(PreTrainedModel):
    """
    LDF Motion Generation Model

    Generates motion sequences from text descriptions using Latent Diffusion
    Forcing.  The underlying diffusion model and VAE are loaded lazily from
    the model directory on first use (see ``_load_models``).

    Example:
        >>> from transformers import AutoModel
        >>> model = AutoModel.from_pretrained("ShandaAI/FloodDiffusion", trust_remote_code=True)
        >>> motion = model("a person walking forward", length=60)
        >>> print(motion.shape)  # (~240, 263)
    """

    config_class = LDFConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Populated lazily by _load_models() (invoked via from_pretrained
        # and again defensively on every __call__).
        self.ldf_model = None
        self.vae = None
        self.model_dir = None  # Model directory, kept for later imports

    def _load_models(self):
        """Load the actual LDF and VAE models (idempotent).

        Raises:
            RuntimeError: if ``name_or_path`` does not point at an existing
                local directory (i.e. the model was not obtained through
                ``from_pretrained``).
        """
        if self.ldf_model is not None:
            return  # Already loaded

        # Get the model directory - should be set by from_pretrained
        if hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path):
            model_dir = self.name_or_path
        else:
            raise RuntimeError(
                "Model directory not found. Please use from_pretrained() to load the model."
            )

        # Save model_dir for later use (e.g., in output_joints conversion)
        self.model_dir = model_dir

        # Add model_dir to sys.path so the repo's own modules are importable
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        # Use dynamic import to avoid HF's static import checker
        import importlib
        generate_ldf = importlib.import_module('generate_ldf')
        load_model_from_config = importlib = generate_ldf.load_model_from_config

        # load_model_from_config() reads --config from sys.argv; patch argv
        # temporarily and always restore it, even on failure.
        config_path = os.path.join(model_dir, "ldf.yaml")
        old_argv = sys.argv
        sys.argv = ['model', '--config', config_path]

        try:
            self.vae, self.ldf_model = load_model_from_config()

            # Move both sub-models onto this wrapper's device (falling back
            # to CUDA/CPU autodetection when the wrapper has no parameters).
            device = next(self.parameters()).device if list(self.parameters()) else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.ldf_model = self.ldf_model.to(device)
            self.vae = self.vae.to(device)
        finally:
            sys.argv = old_argv

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load pretrained model.

        Args:
            pretrained_model_name_or_path: Hub repo id or local path.
            trust_remote_code: Must be True to load this custom model.
            **kwargs: Additional arguments (ignored beyond the flag above).

        Returns:
            LDFModel instance with sub-models already loaded.

        Raises:
            ValueError: if ``trust_remote_code`` is not set.
        """
        # Check trust_remote_code
        if not kwargs.get('trust_remote_code', False):
            raise ValueError(
                "Loading this model requires trust_remote_code=True. "
                "Usage: AutoModel.from_pretrained(..., trust_remote_code=True)"
            )

        # Download if needed
        if not os.path.exists(pretrained_model_name_or_path):
            from huggingface_hub import snapshot_download
            model_path = snapshot_download(repo_id=pretrained_model_name_or_path)
        else:
            model_path = pretrained_model_name_or_path

        # Load config
        config = LDFConfig.from_pretrained(model_path)

        # Create model
        model = cls(config)
        model.name_or_path = model_path

        # Load the actual models
        model._load_models()

        return model

    def forward(
        self,
        text: Union[str, List[str], List[List[str]]],
        length: Union[int, List[int]] = 60,
        text_end: Optional[Union[List[int], List[List[int]]]] = None,
        num_denoise_steps: Optional[int] = None,
        **kwargs
    ):
        """
        Generate motion from text.  Thin delegate to ``__call__``.

        Args:
            text: Text description(s).
            length: Number of latent tokens (output frames ≈ length × 4).
            text_end: Transition points for multi-text generation.
            num_denoise_steps: Number of denoising steps.
            **kwargs: Forwarded to ``__call__`` (e.g. ``output_joints``,
                ``smoothing_alpha``).

        Returns:
            Generated motion sequence(s).
        """
        # BUGFIX: kwargs were previously dropped here, silently ignoring
        # output_joints / smoothing_alpha when calling through forward().
        return self.__call__(text, length, text_end, num_denoise_steps, **kwargs)

    @torch.no_grad()
    def __call__(
        self,
        text: Union[str, List[str], List[List[str]]],
        length: Union[int, List[int]] = 60,
        text_end: Optional[Union[List[int], List[List[int]]]] = None,
        num_denoise_steps: Optional[int] = None,
        output_joints: bool = False,
        smoothing_alpha: float = 1.0
    ):
        """
        Generate motion sequences.

        Args:
            text: Text description
                - Single string: "walk" -> single sample
                - String list: ["walk", "run"] -> batch
                - Nested list: [["walk", "turn"], ["run", "jump"]] -> multi-text per sample
            length: Number of latent tokens (frames ≈ length × 4)
            text_end: Token positions for text switching; when a sample's
                text is a list, text_end must have one endpoint per text.
            num_denoise_steps: Number of denoising steps
            output_joints: If True, output 22×3 joint coordinates; if False
                (default), output 263-dim HumanML3D features
            smoothing_alpha: EMA smoothing factor for joint positions
                (0.0-1.0, default=1.0 no smoothing).  Only used when
                output_joints=True; 0.5 is recommended for smoother motion.

        Returns:
            numpy.ndarray or list of arrays
            - If output_joints=False: shape (frames, 263)
            - If output_joints=True: shape (frames, 22, 3)

        Raises:
            ValueError: if a sample's text list and text_end list differ
                in length.
        """
        # Ensure models are loaded
        self._load_models()

        # Normalize inputs: a scalar length means a single (unbatched) call.
        is_single = not isinstance(length, list)
        if is_single:
            text_batch = [text]
            length_batch = [length]
            text_end_batch = [text_end] if text_end is not None else None
        else:
            text_batch = text
            length_batch = length
            text_end_batch = text_end

        # Validate text_end alignment with text
        if text_end_batch is not None:
            for i, (txt, te) in enumerate(zip(text_batch, text_end_batch)):
                if isinstance(txt, list) and te is not None:
                    if len(txt) != len(te):
                        raise ValueError(
                            f"Batch {i}: text has {len(txt)} segments but text_end has {len(te)} endpoints. "
                            f"They must match! text={txt}, text_end={te}"
                        )

        # Construct input dict for model
        x = {"feature_length": torch.tensor(length_batch), "text": text_batch}
        if text_end_batch is not None:
            x["feature_text_end"] = text_end_batch

        # Non-streaming generation
        output = self.ldf_model.generate(x, num_denoise_steps=num_denoise_steps)
        generated_batch = output["generated"]

        # Decode with VAE and optionally convert to joints
        decoded_results = []
        joints_results = [] if output_joints else None

        # Import the motion processing module once, by file path, so it works
        # regardless of how the repo was placed on sys.path.
        if output_joints:
            import importlib.util
            import numpy as np
            utils_spec = importlib.util.spec_from_file_location(
                "motion_process",
                os.path.join(self.model_dir, "ldf_utils", "motion_process.py")
            )
            motion_process_module = importlib.util.module_from_spec(utils_spec)
            utils_spec.loader.exec_module(motion_process_module)

        for i, generated in enumerate(generated_batch):
            if generated is not None and torch.is_tensor(generated):
                # Decode latent tokens into motion features with the VAE
                decoded_g = self.vae.decode(generated[None, :])[0]

                if output_joints:
                    # Convert to joints using StreamJointRecovery263; a fresh
                    # recovery instance per sample keeps EMA state independent.
                    decoded_np = decoded_g.cpu().numpy()
                    recovery = motion_process_module.StreamJointRecovery263(
                        joints_num=22, smoothing_alpha=smoothing_alpha
                    )
                    joints = [recovery.process_frame(frame) for frame in decoded_np]
                    joints = np.array(joints)
                    joints_results.append(joints)
                else:
                    decoded_results.append(decoded_g.cpu().numpy())
            else:
                # Keep batch alignment even when a sample failed to generate.
                if output_joints:
                    joints_results.append(None)
                else:
                    decoded_results.append(None)

        # Return a bare array for single calls, a list for batches
        if output_joints:
            return joints_results[0] if is_single else joints_results
        else:
            return decoded_results[0] if is_single else decoded_results

    def generate(self, *args, **kwargs):
        """Alias for __call__ to match transformers API"""
        return self.__call__(*args, **kwargs)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# For backwards compatibility
LDFPipeline = LDFModel


# Register the custom "ldf_motion" type with AutoConfig/AutoModel so
# `AutoModel.from_pretrained` can resolve this model. Registration is
# best-effort: it may fail on old transformers versions or if the type
# is already registered, and the pipeline still works without it.
try:
    from transformers import AutoModel, AutoConfig

    AutoConfig.register("ldf_motion", LDFConfig)
    AutoModel.register(LDFConfig, LDFModel)
except Exception:
    # Was a bare `except:`, which would also swallow KeyboardInterrupt and
    # SystemExit; `Exception` keeps the best-effort behavior without that.
    pass
|
ldf.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Experiment bookkeeping
exp_name: ldf
seed: 1234
debug: false
train: false          # inference-only configuration

# Output / checkpoint paths
save_dir: ./outputs
resume_ckpt: null
test_ckpt: "model.safetensors"      # diffusion model weights
test_vae_ckpt: "vae.safetensors"    # motion VAE weights

# Motion VAE used to decode generated latents back to 263-dim features
test_vae:
  target: ldf_models.vae_wan_1d.VAEWanModel
  ema_decay: 0.99
  params:
    input_dim: 263    # HumanML3D-style motion feature dimension
    z_dim: 4          # latent channels (matches model.params.input_dim)

# Evaluation settings
test_setting:
  render: false
  simple: true
  recover_dim: 263

val_repeat: 1

# Diffusion-forcing denoiser operating on VAE latents
model:
  target: ldf_models.diffusion_forcing_wan_tiny.DiffForcingWanModel
  ema_decay: 0.99
  params:
    model_name: "google/umt5-base"  # HuggingFace text encoder
    input_dim: 4                    # latent channels from the VAE
    noise_steps: 10                 # default denoising steps
    hidden_dim: 256
    ffn_dim: 1024
    freq_dim: 64
    num_heads: 8
    num_layers: 8
    time_embedding_scale: 1.0
    chunk_size: 5                   # frames advanced per diffusion-forcing window
    use_text_cond: True
    text_len: 128                   # max text tokens
    drop_out: 0.1                   # text-drop probability for CFG training
    cfg_scale: 5.0                  # classifier-free guidance scale
    prediction_type: "vel"          # "vel", "x0", or "noise"
    causal: False
|
ldf_models/__init__.py
ADDED
|
File without changes
|
ldf_models/diffusion_forcing_wan_tiny.py
ADDED
|
@@ -0,0 +1,943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers import AutoTokenizer, AutoModel
|
| 8 |
+
|
| 9 |
+
from .tools.wan_model import WanModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class HFT5Encoder:
    """Wrap a HuggingFace T5 encoder behind the original T5EncoderModel call
    interface.

    Calling the wrapper with a list of strings returns one embedding tensor
    per string, each cut down to its un-padded token length.
    """

    def __init__(self, text_len, dtype=torch.float32, device=torch.device("cpu"), model_name="google/umt5-base"):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device

        print(f"Loading {model_name} from HuggingFace...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Keep only the encoder half of the seq2seq checkpoint.
        self.model = AutoModel.from_pretrained(
            model_name,
            dtype=dtype
        ).encoder
        self.model.eval()
        self.model.requires_grad_(False)
        self.model.to(device)

    def __call__(self, texts, device):
        """Encode *texts*; return a list of per-text tensors, padding stripped."""
        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.text_len,
            return_tensors="pt"
        )
        token_ids = batch.input_ids.to(device)
        attn_mask = batch.attention_mask.to(device)

        # NOTE(review): assumes self.model is already on *device*; callers
        # move it beforehand via `self.model.to(device)` — confirm at call sites.
        hidden = self.model(input_ids=token_ids, attention_mask=attn_mask).last_hidden_state

        # Valid (non-padding) token count per sample.
        valid_lens = attn_mask.sum(dim=1).long()

        # Same contract as the original T5EncoderModel: padding removed.
        return [seq[:n] for seq, n in zip(hidden, valid_lens)]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DiffForcingWanModel(nn.Module):
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
model_name="google/umt5-base", # HuggingFace model name
|
| 56 |
+
input_dim=256,
|
| 57 |
+
hidden_dim=1024,
|
| 58 |
+
ffn_dim=2048,
|
| 59 |
+
freq_dim=256,
|
| 60 |
+
num_heads=8,
|
| 61 |
+
num_layers=8,
|
| 62 |
+
time_embedding_scale=1.0,
|
| 63 |
+
chunk_size=5,
|
| 64 |
+
noise_steps=10,
|
| 65 |
+
use_text_cond=True,
|
| 66 |
+
text_len=512,
|
| 67 |
+
drop_out=0.1,
|
| 68 |
+
cfg_scale=5.0,
|
| 69 |
+
prediction_type="vel", # "vel", "x0", "noise"
|
| 70 |
+
causal=False,
|
| 71 |
+
):
|
| 72 |
+
super().__init__()
|
| 73 |
+
|
| 74 |
+
self.input_dim = input_dim
|
| 75 |
+
self.hidden_dim = hidden_dim
|
| 76 |
+
self.ffn_dim = ffn_dim
|
| 77 |
+
self.freq_dim = freq_dim
|
| 78 |
+
self.num_heads = num_heads
|
| 79 |
+
self.num_layers = num_layers
|
| 80 |
+
self.time_embedding_scale = time_embedding_scale
|
| 81 |
+
self.chunk_size = chunk_size
|
| 82 |
+
self.noise_steps = noise_steps
|
| 83 |
+
self.use_text_cond = use_text_cond
|
| 84 |
+
self.drop_out = drop_out
|
| 85 |
+
self.cfg_scale = cfg_scale
|
| 86 |
+
self.prediction_type = prediction_type
|
| 87 |
+
self.causal = causal
|
| 88 |
+
|
| 89 |
+
self.text_dim = 768 # umt5-base hidden size
|
| 90 |
+
self.text_len = text_len
|
| 91 |
+
self.model_name = model_name
|
| 92 |
+
|
| 93 |
+
# Load model and tokenizer from HuggingFace
|
| 94 |
+
print(f"Loading {model_name} from HuggingFace...")
|
| 95 |
+
self.text_encoder = HFT5Encoder(
|
| 96 |
+
text_len=text_len,
|
| 97 |
+
dtype=torch.bfloat16,
|
| 98 |
+
device=torch.device("cpu"),
|
| 99 |
+
model_name=model_name,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Text encoding cache
|
| 103 |
+
self.text_cache = {}
|
| 104 |
+
self.model = WanModel(
|
| 105 |
+
model_type="t2v",
|
| 106 |
+
patch_size=(1, 1, 1),
|
| 107 |
+
text_len=self.text_len,
|
| 108 |
+
in_dim=self.input_dim,
|
| 109 |
+
dim=self.hidden_dim,
|
| 110 |
+
ffn_dim=self.ffn_dim,
|
| 111 |
+
freq_dim=self.freq_dim,
|
| 112 |
+
text_dim=self.text_dim,
|
| 113 |
+
out_dim=self.input_dim,
|
| 114 |
+
num_heads=self.num_heads,
|
| 115 |
+
num_layers=self.num_layers,
|
| 116 |
+
window_size=(-1, -1),
|
| 117 |
+
qk_norm=True,
|
| 118 |
+
cross_attn_norm=True,
|
| 119 |
+
eps=1e-6,
|
| 120 |
+
causal=self.causal,
|
| 121 |
+
)
|
| 122 |
+
self.param_dtype = torch.float32
|
| 123 |
+
|
| 124 |
+
def encode_text_with_cache(self, text_list, device):
|
| 125 |
+
"""Encode text using cache
|
| 126 |
+
Args:
|
| 127 |
+
text_list: List[str], list of texts
|
| 128 |
+
device: torch.device
|
| 129 |
+
Returns:
|
| 130 |
+
List[Tensor]: List of encoded text features
|
| 131 |
+
"""
|
| 132 |
+
text_features = []
|
| 133 |
+
indices_to_encode = []
|
| 134 |
+
texts_to_encode = []
|
| 135 |
+
|
| 136 |
+
# Check cache
|
| 137 |
+
for i, text in enumerate(text_list):
|
| 138 |
+
if text in self.text_cache:
|
| 139 |
+
# Get from cache and move to correct device
|
| 140 |
+
cached_feature = self.text_cache[text].to(device)
|
| 141 |
+
text_features.append(cached_feature)
|
| 142 |
+
else:
|
| 143 |
+
# Need to encode
|
| 144 |
+
text_features.append(None)
|
| 145 |
+
indices_to_encode.append(i)
|
| 146 |
+
texts_to_encode.append(text)
|
| 147 |
+
|
| 148 |
+
# Batch encode uncached texts
|
| 149 |
+
if texts_to_encode:
|
| 150 |
+
self.text_encoder.model.to(device)
|
| 151 |
+
encoded = self.text_encoder(texts_to_encode, device)
|
| 152 |
+
|
| 153 |
+
# Store in cache and update results
|
| 154 |
+
for idx, text, feature in zip(indices_to_encode, texts_to_encode, encoded):
|
| 155 |
+
# Cache to CPU to save GPU memory
|
| 156 |
+
self.text_cache[text] = feature.cpu()
|
| 157 |
+
text_features[idx] = feature
|
| 158 |
+
|
| 159 |
+
return text_features
|
| 160 |
+
|
| 161 |
+
def preprocess(self, x):
|
| 162 |
+
# (bs, T, C) -> (bs, C, T, 1, 1)
|
| 163 |
+
x = x.permute(0, 2, 1)[:, :, :, None, None]
|
| 164 |
+
return x
|
| 165 |
+
|
| 166 |
+
def postprocess(self, x):
|
| 167 |
+
# (bs, C, T, 1, 1) -> (bs, T, C)
|
| 168 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous().view(x.size(0), x.size(2), -1)
|
| 169 |
+
return x
|
| 170 |
+
|
| 171 |
+
def _get_noise_levels(self, device, seq_len, time_steps):
|
| 172 |
+
"""Get noise levels"""
|
| 173 |
+
# noise_level[i] = clip(1 + i / chunk_size - time_steps, 0, 1)
|
| 174 |
+
noise_level = torch.clamp(
|
| 175 |
+
1
|
| 176 |
+
+ torch.arange(seq_len, device=device) / self.chunk_size
|
| 177 |
+
- time_steps.unsqueeze(1),
|
| 178 |
+
min=0.0,
|
| 179 |
+
max=1.0,
|
| 180 |
+
)
|
| 181 |
+
return noise_level
|
| 182 |
+
|
| 183 |
+
def add_noise(self, x, noise_level):
|
| 184 |
+
"""Add noise
|
| 185 |
+
Args:
|
| 186 |
+
x: (B, T, D)
|
| 187 |
+
noise_level: (B, T)
|
| 188 |
+
"""
|
| 189 |
+
noise = torch.randn_like(x)
|
| 190 |
+
# noise_level: (B, T) -> (B, T, 1)
|
| 191 |
+
noise_level = noise_level.unsqueeze(-1)
|
| 192 |
+
noisy_x = x * (1 - noise_level) + noise_level * noise
|
| 193 |
+
return noisy_x, noise
|
| 194 |
+
|
| 195 |
+
    def forward(self, x):
        """Run one diffusion-forcing training step and return the loss dict.

        A continuous "time step" t is drawn per sample; it induces a triangular
        noise schedule over the sequence (clean left, pure noise right). Each
        sample is truncated just past the noising frontier, and only the last
        `chunk_size` positions contribute to the loss.

        Args:
            x: dict with
                "feature": (B, T, C) latent motion features.
                "feature_length": (B,) valid length per sample.
                "text" (optional): List[str], or List[List[str]] of per-segment
                    prompts; in the latter case "feature_text_end" must give
                    the segment end frames.

        Returns:
            dict: {"total": scalar loss, "mse": the same scalar}.
        """
        feature = x["feature"]  # (B, T, C)
        feature_length = x["feature_length"]  # (B,)
        batch_size, seq_len, _ = feature.shape
        device = feature.device

        # Sample one continuous diffusion-forcing time per batch element.
        time_steps = []
        for i in range(batch_size):
            valid_len = feature_length[i].item()
            # Random float in [0, valid_len/chunk_size) — deliberately not an integer.
            max_time = valid_len / self.chunk_size
            # max_time = valid_len / self.chunk_size + 1
            time_steps.append(torch.FloatTensor(1).uniform_(0, max_time).item())
        time_steps = torch.tensor(time_steps, device=device)  # (B,)
        noise_level = self._get_noise_levels(device, seq_len, time_steps)  # (B, T)

        # (debug prints of per-position noise levels removed; see git history)

        # Noise the entire sequence according to the triangular schedule.
        noisy_feature, noise = self.add_noise(feature, noise_level)  # (B, T, D)

        # (debug checks that noisy_feature interpolates feature/noise removed)

        feature = self.preprocess(feature)  # (B, C, T, 1, 1)
        noisy_feature = self.preprocess(noisy_feature)  # (B, C, T, 1, 1)
        noise = self.preprocess(noise)  # (B, C, T, 1, 1)

        # Truncate each sample just past the noising frontier: positions
        # beyond int(chunk_size * t) + 1 are pure noise with no signal.
        feature_ref = []
        noise_ref = []
        noisy_feature_input = []
        for i in range(batch_size):
            t = time_steps[i].item()
            end_index = int(self.chunk_size * t) + 1
            valid_len = feature_length[i].item()
            end_index = min(valid_len, end_index)
            feature_ref.append(feature[i, :, :end_index, ...])
            noise_ref.append(noise[i, :, :end_index, ...])
            noisy_feature_input.append(noisy_feature[i, :, :end_index, ...])

        # Encode text condition (using cache). With probability drop_out the
        # prompt is replaced by "" — classifier-free guidance training.
        if self.use_text_cond and "text" in x:
            text_list = x["text"]  # List[str] or List[List[str]]
            if isinstance(text_list[0], list):
                # Segmented prompts: expand each segment's embedding to one
                # entry per frame, padding the tail with the last segment.
                text_end_list = x["feature_text_end"]
                all_text_context = []
                for single_text_list, single_text_end_list in zip(
                    text_list, text_end_list
                ):
                    if np.random.rand() > self.drop_out:
                        single_text_end_list = [0] + [
                            min(t, seq_len) for t in single_text_end_list
                        ]
                    else:
                        # Dropped sample: a single empty prompt spanning everything.
                        single_text_list = [""]
                        single_text_end_list = [0, seq_len]
                    single_text_length_list = [
                        t - b
                        for t, b in zip(
                            single_text_end_list[1:], single_text_end_list[:-1]
                        )
                    ]
                    single_text_context = self.encode_text_with_cache(
                        single_text_list, device
                    )
                    single_text_context = [
                        u.to(self.param_dtype) for u in single_text_context
                    ]
                    for u, duration in zip(
                        single_text_context, single_text_length_list
                    ):
                        all_text_context.extend([u for _ in range(duration)])
                    # Frames after the last segment end reuse its embedding.
                    all_text_context.extend(
                        [
                            single_text_context[-1]
                            for _ in range(seq_len - single_text_end_list[-1])
                        ]
                    )
            else:
                # One prompt per sample.
                all_text_context = [
                    (u if np.random.rand() > self.drop_out else "") for u in text_list
                ]
                all_text_context = self.encode_text_with_cache(all_text_context, device)
                all_text_context = [u.to(self.param_dtype) for u in all_text_context]
        else:
            # Unconditional: empty prompts for every sample.
            all_text_context = [""] * batch_size
            all_text_context = self.encode_text_with_cache(all_text_context, device)
            all_text_context = [u.to(self.param_dtype) for u in all_text_context]

        # Denoise the truncated, noisy sequences with the WanModel backbone.
        predicted_result = self.model(
            noisy_feature_input,
            noise_level * self.time_embedding_scale,
            all_text_context,
            seq_len,
            y=None,
        )  # (B, C, T, 1, 1)

        # Per-sample squared error on the last chunk only, averaged over batch.
        loss = 0.0
        for b in range(batch_size):
            if self.prediction_type == "vel":
                # Rectified-flow velocity target: x0 - noise.
                vel = feature_ref[b] - noise_ref[b]  # (C, input_length, 1, 1)
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - vel[:, -self.chunk_size :, ...]
                ) ** 2
            elif self.prediction_type == "x0":
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - feature_ref[b][:, -self.chunk_size :, ...]
                ) ** 2
            elif self.prediction_type == "noise":
                squared_error = (
                    predicted_result[b][:, -self.chunk_size :, ...]
                    - noise_ref[b][:, -self.chunk_size :, ...]
                ) ** 2
            # NOTE(review): .mean() of a scalar sum is a no-op; loss is a raw
            # sum per sample. An unknown prediction_type would raise NameError.
            sample_loss = squared_error.sum().mean()
            loss += sample_loss
        loss = loss / batch_size

        loss_dict = {"total": loss, "mse": loss}
        return loss_dict
|
| 345 |
+
|
| 346 |
+
def generate(self, x, num_denoise_steps=None):
|
| 347 |
+
"""
|
| 348 |
+
Generation - Diffusion Forcing inference
|
| 349 |
+
Uses triangular noise schedule, progressively generating from left to right
|
| 350 |
+
|
| 351 |
+
Generation process:
|
| 352 |
+
1. Start from t=0, gradually increase t
|
| 353 |
+
2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
|
| 354 |
+
3. After each denoising step, t increases slightly and continues
|
| 355 |
+
"""
|
| 356 |
+
feature_length = x["feature_length"]
|
| 357 |
+
batch_size = len(feature_length)
|
| 358 |
+
seq_len = max(feature_length).item()
|
| 359 |
+
|
| 360 |
+
# # debug
|
| 361 |
+
# x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)]
|
| 362 |
+
# x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)]
|
| 363 |
+
# text = x["text"]
|
| 364 |
+
# text_end = x["feature_text_end"]
|
| 365 |
+
# print(text)
|
| 366 |
+
# print(text_end)
|
| 367 |
+
# print(batch_size, seq_len, self.chunk_size)
|
| 368 |
+
|
| 369 |
+
if num_denoise_steps is None:
|
| 370 |
+
num_denoise_steps = self.noise_steps
|
| 371 |
+
assert num_denoise_steps % self.chunk_size == 0
|
| 372 |
+
|
| 373 |
+
device = next(self.parameters()).device
|
| 374 |
+
|
| 375 |
+
# Initialize entire sequence as pure noise
|
| 376 |
+
generated = torch.randn(
|
| 377 |
+
batch_size, seq_len + self.chunk_size, self.input_dim, device=device
|
| 378 |
+
)
|
| 379 |
+
generated = self.preprocess(generated) # (B, C, T, 1, 1)
|
| 380 |
+
|
| 381 |
+
# Calculate total number of time steps needed
|
| 382 |
+
max_t = 1 + (seq_len - 1) / self.chunk_size
|
| 383 |
+
|
| 384 |
+
# Step size for each advancement
|
| 385 |
+
dt = 1 / num_denoise_steps
|
| 386 |
+
total_steps = int(max_t / dt)
|
| 387 |
+
|
| 388 |
+
# Encode text condition (using cache)
|
| 389 |
+
if self.use_text_cond and "text" in x:
|
| 390 |
+
text_list = x["text"] # List[str] or List[List[str]]
|
| 391 |
+
if isinstance(text_list[0], list):
|
| 392 |
+
generated_length = []
|
| 393 |
+
text_end_list = x["feature_text_end"]
|
| 394 |
+
full_text = []
|
| 395 |
+
all_text_context = []
|
| 396 |
+
for single_text_list, single_text_end_list in zip(
|
| 397 |
+
text_list, text_end_list
|
| 398 |
+
):
|
| 399 |
+
single_text_end_list = [0] + [
|
| 400 |
+
min(t, seq_len) for t in single_text_end_list
|
| 401 |
+
]
|
| 402 |
+
generated_length.append(single_text_end_list[-1])
|
| 403 |
+
single_text_length_list = [
|
| 404 |
+
t - b
|
| 405 |
+
for t, b in zip(
|
| 406 |
+
single_text_end_list[1:], single_text_end_list[:-1]
|
| 407 |
+
)
|
| 408 |
+
]
|
| 409 |
+
full_text.append(
|
| 410 |
+
" ////////// ".join(
|
| 411 |
+
[
|
| 412 |
+
f"{u} //dur:{t}"
|
| 413 |
+
for u, t in zip(
|
| 414 |
+
single_text_list, single_text_length_list
|
| 415 |
+
)
|
| 416 |
+
]
|
| 417 |
+
)
|
| 418 |
+
)
|
| 419 |
+
single_text_context = self.encode_text_with_cache(
|
| 420 |
+
single_text_list, device
|
| 421 |
+
)
|
| 422 |
+
single_text_context = [
|
| 423 |
+
u.to(self.param_dtype) for u in single_text_context
|
| 424 |
+
]
|
| 425 |
+
for u, duration in zip(
|
| 426 |
+
single_text_context, single_text_length_list
|
| 427 |
+
):
|
| 428 |
+
all_text_context.extend([u for _ in range(duration)])
|
| 429 |
+
all_text_context.extend(
|
| 430 |
+
[
|
| 431 |
+
single_text_context[-1]
|
| 432 |
+
for _ in range(
|
| 433 |
+
seq_len + self.chunk_size - single_text_end_list[-1]
|
| 434 |
+
)
|
| 435 |
+
]
|
| 436 |
+
)
|
| 437 |
+
else:
|
| 438 |
+
generated_length = feature_length
|
| 439 |
+
full_text = text_list
|
| 440 |
+
all_text_context = self.encode_text_with_cache(text_list, device)
|
| 441 |
+
all_text_context = [u.to(self.param_dtype) for u in all_text_context]
|
| 442 |
+
else:
|
| 443 |
+
generated_length = feature_length
|
| 444 |
+
full_text = [""] * batch_size
|
| 445 |
+
all_text_context = [""] * batch_size
|
| 446 |
+
all_text_context = self.encode_text_with_cache(all_text_context, device)
|
| 447 |
+
all_text_context = [u.to(self.param_dtype) for u in all_text_context]
|
| 448 |
+
|
| 449 |
+
# Get empty text condition encoding (for CFG)
|
| 450 |
+
text_null_list = [""] * batch_size
|
| 451 |
+
text_null_context = self.encode_text_with_cache(text_null_list, device)
|
| 452 |
+
text_null_context = [u.to(self.param_dtype) for u in text_null_context]
|
| 453 |
+
|
| 454 |
+
# print(len(all_text_context), len(text_null_context))
|
| 455 |
+
|
| 456 |
+
# Progressively advance from t=0 to t=max_t
|
| 457 |
+
for step in range(total_steps):
|
| 458 |
+
# Current time step
|
| 459 |
+
t = step * dt
|
| 460 |
+
start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
|
| 461 |
+
end_index = int(self.chunk_size * t) + 1
|
| 462 |
+
time_steps = torch.full((batch_size,), t, device=device)
|
| 463 |
+
|
| 464 |
+
# Calculate current noise schedule
|
| 465 |
+
noise_level = self._get_noise_levels(
|
| 466 |
+
device, seq_len + self.chunk_size, time_steps
|
| 467 |
+
) # (B, T)
|
| 468 |
+
|
| 469 |
+
# Predict noise through WanModel
|
| 470 |
+
noisy_input = []
|
| 471 |
+
for i in range(batch_size):
|
| 472 |
+
noisy_input.append(generated[i, :, :end_index, ...])
|
| 473 |
+
|
| 474 |
+
predicted_result = self.model(
|
| 475 |
+
noisy_input,
|
| 476 |
+
noise_level * self.time_embedding_scale,
|
| 477 |
+
all_text_context,
|
| 478 |
+
seq_len + self.chunk_size,
|
| 479 |
+
y=None,
|
| 480 |
+
) # (B, C, T, 1, 1)
|
| 481 |
+
|
| 482 |
+
# Adjust using CFG
|
| 483 |
+
if self.cfg_scale != 1.0:
|
| 484 |
+
predicted_result_null = self.model(
|
| 485 |
+
noisy_input,
|
| 486 |
+
noise_level * self.time_embedding_scale,
|
| 487 |
+
text_null_context,
|
| 488 |
+
seq_len + self.chunk_size,
|
| 489 |
+
y=None,
|
| 490 |
+
) # (B, C, T, 1, 1)
|
| 491 |
+
predicted_result = [
|
| 492 |
+
self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
|
| 493 |
+
for pv, pvn in zip(predicted_result, predicted_result_null)
|
| 494 |
+
]
|
| 495 |
+
|
| 496 |
+
for i in range(batch_size):
|
| 497 |
+
predicted_result_i = predicted_result[i] # (C, input_length, 1, 1)
|
| 498 |
+
if self.prediction_type == "vel":
|
| 499 |
+
predicted_vel = predicted_result_i[:, start_index:end_index, ...]
|
| 500 |
+
generated[i, :, start_index:end_index, ...] += predicted_vel * dt
|
| 501 |
+
elif self.prediction_type == "x0":
|
| 502 |
+
predicted_vel = (
|
| 503 |
+
predicted_result_i[:, start_index:end_index, ...]
|
| 504 |
+
- generated[i, :, start_index:end_index, ...]
|
| 505 |
+
) / (
|
| 506 |
+
noise_level[i, start_index:end_index]
|
| 507 |
+
.unsqueeze(0)
|
| 508 |
+
.unsqueeze(-1)
|
| 509 |
+
.unsqueeze(-1)
|
| 510 |
+
)
|
| 511 |
+
generated[i, :, start_index:end_index, ...] += predicted_vel * dt
|
| 512 |
+
elif self.prediction_type == "noise":
|
| 513 |
+
predicted_vel = (
|
| 514 |
+
generated[i, :, start_index:end_index, ...]
|
| 515 |
+
- predicted_result_i[:, start_index:end_index, ...]
|
| 516 |
+
) / (
|
| 517 |
+
1
|
| 518 |
+
+ dt
|
| 519 |
+
- noise_level[i, start_index:end_index]
|
| 520 |
+
.unsqueeze(0)
|
| 521 |
+
.unsqueeze(-1)
|
| 522 |
+
.unsqueeze(-1)
|
| 523 |
+
)
|
| 524 |
+
generated[i, :, start_index:end_index, ...] += predicted_vel * dt
|
| 525 |
+
|
| 526 |
+
generated = self.postprocess(generated) # (B, T, C)
|
| 527 |
+
y_hat_out = []
|
| 528 |
+
for i in range(batch_size):
|
| 529 |
+
# cut off the padding
|
| 530 |
+
single_generated = generated[i, : generated_length[i], :]
|
| 531 |
+
y_hat_out.append(single_generated)
|
| 532 |
+
out = {}
|
| 533 |
+
out["generated"] = y_hat_out
|
| 534 |
+
out["text"] = full_text
|
| 535 |
+
|
| 536 |
+
return out
|
| 537 |
+
|
| 538 |
+
@torch.no_grad()
def stream_generate(self, x, num_denoise_steps=None):
    """
    Streaming generation - Diffusion Forcing inference.

    Uses a triangular noise schedule, progressively generating from left to
    right, yielding each newly committed (fully denoised) segment as it
    becomes available.

    Generation process:
    1. Start from t=0, gradually increase t
    2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
    3. After each denoising step, t increases slightly and continues

    Args:
        x: batch dict. Reads "feature_length" (per-sample lengths) and,
           when text conditioning is enabled, "text" (List[str] or
           List[List[str]]) plus "feature_text_end" for segmented prompts.
        num_denoise_steps: denoising steps per chunk; defaults to
            ``self.noise_steps``. Must be divisible by ``self.chunk_size``.

    Yields:
        dict with key "generated": list (per sample) of tensors holding the
        frames committed since the previous yield, or None once that sample
        has produced all of its requested frames.
    """
    feature_length = x["feature_length"]
    batch_size = len(feature_length)
    # Pad every sample up to the longest requested length in the batch.
    seq_len = max(feature_length).item()

    if num_denoise_steps is None:
        num_denoise_steps = self.noise_steps
    assert num_denoise_steps % self.chunk_size == 0

    device = next(self.parameters()).device

    # Initialize entire sequence (plus one chunk of headroom) as pure noise
    generated = torch.randn(
        batch_size, seq_len + self.chunk_size, self.input_dim, device=device
    )
    generated = self.preprocess(generated)  # (B, C, T, 1, 1)

    # Calculate total "time" needed so the denoising front sweeps the whole
    # sequence (the front advances one chunk per unit of t).
    max_t = 1 + (seq_len - 1) / self.chunk_size

    # Step size for each advancement
    dt = 1 / num_denoise_steps
    total_steps = int(max_t / dt)

    # Encode text condition (using cache)
    if self.use_text_cond and "text" in x:
        text_list = x["text"]  # List[str] or List[List[str]]
        if isinstance(text_list[0], list):
            # Segmented prompts: each sub-prompt covers the frame span ending
            # at the matching entry of "feature_text_end".
            generated_length = []
            text_end_list = x["feature_text_end"]
            full_text = []
            all_text_context = []
            for single_text_list, single_text_end_list in zip(
                text_list, text_end_list
            ):
                # Clip segment ends to seq_len and prepend 0 so consecutive
                # differences give per-segment durations.
                single_text_end_list = [0] + [
                    min(t, seq_len) for t in single_text_end_list
                ]
                generated_length.append(single_text_end_list[-1])
                single_text_length_list = [
                    t - b
                    for t, b in zip(
                        single_text_end_list[1:], single_text_end_list[:-1]
                    )
                ]
                # Human-readable summary of the composite prompt (returned to
                # the caller alongside the motion in other code paths).
                full_text.append(
                    " ////////// ".join(
                        [
                            f"{u} //dur:{t}"
                            for u, t in zip(
                                single_text_list, single_text_length_list
                            )
                        ]
                    )
                )
                single_text_context = self.encode_text_with_cache(
                    single_text_list, device
                )
                single_text_context = [
                    u.to(self.param_dtype) for u in single_text_context
                ]
                # One embedding per frame: repeat each segment's embedding for
                # its duration; the last segment also covers the padded tail.
                for u, duration in zip(
                    single_text_context, single_text_length_list
                ):
                    all_text_context.extend([u for _ in range(duration)])
                all_text_context.extend(
                    [
                        single_text_context[-1]
                        for _ in range(
                            seq_len + self.chunk_size - single_text_end_list[-1]
                        )
                    ]
                )
        else:
            # One flat prompt per sample.
            generated_length = feature_length
            full_text = text_list
            all_text_context = self.encode_text_with_cache(text_list, device)
            all_text_context = [u.to(self.param_dtype) for u in all_text_context]
    else:
        # Unconditional generation: fall back to empty-string embeddings.
        generated_length = feature_length
        full_text = [""] * batch_size
        all_text_context = [""] * batch_size
        all_text_context = self.encode_text_with_cache(all_text_context, device)
        all_text_context = [u.to(self.param_dtype) for u in all_text_context]

    # Get empty text condition encoding (for CFG)
    text_null_list = [""] * batch_size
    text_null_context = self.encode_text_with_cache(text_null_list, device)
    text_null_context = [u.to(self.param_dtype) for u in text_null_context]

    # Frames before commit_index are final and have already been yielded.
    commit_index = 0
    # Progressively advance from t=0 to t=max_t
    for step in range(total_steps):
        # Current time step
        t = step * dt
        # Frames in [start_index, end_index) are in the active denoising band.
        start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
        end_index = int(self.chunk_size * t) + 1
        time_steps = torch.full((batch_size,), t, device=device)

        # Calculate current noise schedule
        noise_level = self._get_noise_levels(
            device, seq_len + self.chunk_size, time_steps
        )  # (B, T)

        # Predict noise through WanModel
        noisy_input = []
        for i in range(batch_size):
            noisy_input.append(generated[i, :, :end_index, ...])

        predicted_result = self.model(
            noisy_input,
            noise_level * self.time_embedding_scale,
            all_text_context,
            seq_len + self.chunk_size,
            y=None,
        )  # (B, C, T, 1, 1)

        # Adjust using CFG
        if self.cfg_scale != 1.0:
            predicted_result_null = self.model(
                noisy_input,
                noise_level * self.time_embedding_scale,
                text_null_context,
                seq_len + self.chunk_size,
                y=None,
            )  # (B, C, T, 1, 1)
            predicted_result = [
                self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
                for pv, pvn in zip(predicted_result, predicted_result_null)
            ]

        for i in range(batch_size):
            predicted_result_i = predicted_result[i]  # (C, input_length, 1, 1)
            if self.prediction_type == "vel":
                # Model predicts the flow velocity directly.
                predicted_vel = predicted_result_i[:, start_index:end_index, ...]
                generated[i, :, start_index:end_index, ...] += predicted_vel * dt
            elif self.prediction_type == "x0":
                # Model predicts the clean sample; convert to a velocity by
                # dividing the residual by the remaining noise level.
                predicted_vel = (
                    predicted_result_i[:, start_index:end_index, ...]
                    - generated[i, :, start_index:end_index, ...]
                ) / (
                    noise_level[i, start_index:end_index]
                    .unsqueeze(0)
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )
                generated[i, :, start_index:end_index, ...] += predicted_vel * dt
            elif self.prediction_type == "noise":
                # Model predicts the noise; convert to a velocity.
                predicted_vel = (
                    generated[i, :, start_index:end_index, ...]
                    - predicted_result_i[:, start_index:end_index, ...]
                ) / (
                    1
                    + dt
                    - noise_level[i, start_index:end_index]
                    .unsqueeze(0)
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )
                generated[i, :, start_index:end_index, ...] += predicted_vel * dt

        # Everything left of start_index is now fully denoised: commit it.
        if commit_index < start_index:
            output = generated[:, :, commit_index:start_index, ...]
            output = self.postprocess(output)  # (B, T, C)
            y_hat_out = []
            for i in range(batch_size):
                if commit_index < generated_length[i]:
                    y_hat_out.append(
                        output[i, : generated_length[i] - commit_index, ...]
                    )
                else:
                    # Sample already produced all of its requested frames.
                    y_hat_out.append(None)

            out = {}
            out["generated"] = y_hat_out
            yield out
            commit_index = start_index

    # Flush whatever remains after the final denoising step.
    output = generated[:, :, commit_index:, ...]
    output = self.postprocess(output)  # (B, T_remain, C)
    y_hat_out = []
    for i in range(batch_size):
        if commit_index < generated_length[i]:
            y_hat_out.append(output[i, : generated_length[i] - commit_index, ...])
        else:
            y_hat_out.append(None)
    out = {}
    out["generated"] = y_hat_out
    yield out
| 749 |
+
def init_generated(self, seq_len, batch_size=1, num_denoise_steps=None):
    """Reset all per-stream state used by ``stream_generate_step``.

    Allocates a noise buffer twice as long as the sliding window (plus one
    chunk of headroom) and zeroes all progress counters. Must be called
    once before the first ``stream_generate_step`` call.

    Args:
        seq_len: sliding-window length in latent frames.
        batch_size: number of parallel streams.
        num_denoise_steps: denoising steps per chunk; defaults to
            ``self.noise_steps``. Must be divisible by ``self.chunk_size``.
    """
    self.seq_len = seq_len
    self.batch_size = batch_size
    if num_denoise_steps is None:
        self.num_denoise_steps = self.noise_steps
    else:
        self.num_denoise_steps = num_denoise_steps
    assert self.num_denoise_steps % self.chunk_size == 0
    # Time advances by dt per denoising step.
    self.dt = 1 / self.num_denoise_steps
    self.current_step = 0
    # Per-sample list of one text embedding per frame.
    self.text_condition_list = [[] for _ in range(self.batch_size)]
    # Double-length buffer so the window can slide before re-noising.
    # NOTE: allocated on CPU; stream_generate_step moves it to the model
    # device on its first call.
    noise = torch.randn(
        self.batch_size, self.seq_len * 2 + self.chunk_size, self.input_dim
    )
    self.generated = self.preprocess(noise)  # (B, C, T, 1, 1)
    self.commit_index = 0
|
| 766 |
+
@torch.no_grad()
def stream_generate_step(self, x, first_chunk=True):
    """
    Streaming generation step - Diffusion Forcing inference.

    Produces exactly one newly committed latent frame per call, denoising a
    sliding window of at most ``self.seq_len`` frames. All progress state
    (noise buffer, counters, per-frame text embeddings) lives on ``self``
    and must be initialized via ``init_generated`` before the first call.

    Generation process:
    1. Start from t=0, gradually increase t
    2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
    3. After each denoising step, t increases slightly and continues

    Args:
        x: dict; when text conditioning is enabled, reads "text" (List[str],
           one prompt per stream) for the frame(s) added by this call.
        first_chunk: True on the very first call — moves the noise buffer to
            the model device and seeds a full chunk of text embeddings
            instead of a single frame.

    Returns:
        dict with key "generated": tensor of shape (B, 1, C) holding the one
        frame committed by this call.
    """

    device = next(self.parameters()).device
    if first_chunk:
        # Buffer was allocated on CPU in init_generated.
        self.generated = self.generated.to(device)

    # Encode text condition (using cache)
    if self.use_text_cond and "text" in x:
        text_list = x["text"]  # List[str]
        new_text_context = self.encode_text_with_cache(text_list, device)
        new_text_context = [u.to(self.param_dtype) for u in new_text_context]
    else:
        new_text_context = [""] * self.batch_size
        new_text_context = self.encode_text_with_cache(new_text_context, device)
        new_text_context = [u.to(self.param_dtype) for u in new_text_context]

    # Get empty text condition encoding (for CFG)
    text_null_list = [""] * self.batch_size
    text_null_context = self.encode_text_with_cache(text_null_list, device)
    text_null_context = [u.to(self.param_dtype) for u in text_null_context]

    # Maintain one text embedding per frame; the first call seeds an entire
    # chunk, later calls append a single frame.
    for i in range(self.batch_size):
        if first_chunk:
            self.text_condition_list[i].extend(
                [new_text_context[i]] * self.chunk_size
            )
        else:
            self.text_condition_list[i].extend([new_text_context[i]])

    # Run denoising until the frame at commit_index is fully clean
    # (num_denoise_steps / chunk_size inner steps advance one frame).
    end_step = (
        (self.commit_index + self.chunk_size)
        * self.num_denoise_steps
        / self.chunk_size
    )
    while self.current_step < end_step:
        current_time = self.current_step * self.dt
        # Frames in [start_index, end_index) form the active denoising band
        # (absolute indices into self.generated).
        start_index = max(0, int(self.chunk_size * (current_time - 1)) + 1)
        end_index = int(self.chunk_size * current_time) + 1
        time_steps = torch.full((self.batch_size,), current_time, device=device)

        # Noise schedule restricted to the sliding window (last seq_len frames).
        noise_level = self._get_noise_levels(device, end_index, time_steps)[
            :, -self.seq_len :
        ]  # (B, T)

        # Predict noise through WanModel
        noisy_input = []
        for i in range(self.batch_size):
            noisy_input.append(
                self.generated[i, :, :end_index, ...][:, -self.seq_len :]
            )  # (C, T, 1, 1)

        text_condition = []
        for i in range(self.batch_size):
            text_condition.extend(
                self.text_condition_list[i][:end_index][-self.seq_len :]
            )  # (T, D, 4096)

        predicted_result = self.model(
            noisy_input,
            noise_level * self.time_embedding_scale,
            text_condition,
            min(end_index, self.seq_len),
            y=None,
        )  # (B, C, T, 1, 1)

        # Adjust using CFG
        if self.cfg_scale != 1.0:
            predicted_result_null = self.model(
                noisy_input,
                noise_level * self.time_embedding_scale,
                text_null_context,
                min(end_index, self.seq_len),
                y=None,
            )  # (B, C, T, 1, 1)
            predicted_result = [
                self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
                for pv, pvn in zip(predicted_result, predicted_result_null)
            ]

        for i in range(self.batch_size):
            predicted_result_i = predicted_result[i]  # (C, input_length, 1, 1)
            if end_index > self.seq_len:
                # The model only saw the last seq_len frames; left-pad its
                # output with zeros so the indices below remain absolute.
                predicted_result_i = torch.cat(
                    [
                        torch.zeros(
                            predicted_result_i.shape[0],
                            end_index - self.seq_len,
                            predicted_result_i.shape[2],
                            predicted_result_i.shape[3],
                            device=device,
                        ),
                        predicted_result_i,
                    ],
                    dim=1,
                )
            if self.prediction_type == "vel":
                # Model predicts the flow velocity directly.
                predicted_vel = predicted_result_i[:, start_index:end_index, ...]
                self.generated[i, :, start_index:end_index, ...] += (
                    predicted_vel * self.dt
                )
            elif self.prediction_type == "x0":
                # Model predicts the clean sample; convert to a velocity.
                predicted_vel = (
                    predicted_result_i[:, start_index:end_index, ...]
                    - self.generated[i, :, start_index:end_index, ...]
                ) / (
                    noise_level[i, start_index:end_index]
                    .unsqueeze(0)
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )
                self.generated[i, :, start_index:end_index, ...] += (
                    predicted_vel * self.dt
                )
            elif self.prediction_type == "noise":
                # Model predicts the noise; convert to a velocity.
                predicted_vel = (
                    self.generated[i, :, start_index:end_index, ...]
                    - predicted_result_i[:, start_index:end_index, ...]
                ) / (
                    1
                    + self.dt
                    - noise_level[i, start_index:end_index]
                    .unsqueeze(0)
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )
                self.generated[i, :, start_index:end_index, ...] += (
                    predicted_vel * self.dt
                )
        self.current_step += 1
    # Commit exactly one fully denoised frame.
    output = self.generated[:, :, self.commit_index : self.commit_index + 1, ...]
    output = self.postprocess(output)  # (B, 1, C)
    out = {}
    out["generated"] = output
    self.commit_index += 1

    # Once the double-length buffer is exhausted, slide the window left by
    # seq_len, append fresh noise, and rebase all counters accordingly.
    if self.commit_index == self.seq_len * 2:
        self.generated = torch.cat(
            [
                self.generated[:, :, self.seq_len :, ...],
                torch.randn(
                    self.batch_size,
                    self.input_dim,
                    self.seq_len,
                    1,
                    1,
                    device=device,
                ),
            ],
            dim=2,
        )
        self.current_step -= self.seq_len * self.num_denoise_steps / self.chunk_size
        self.commit_index -= self.seq_len
        for i in range(self.batch_size):
            self.text_condition_list[i] = self.text_condition_list[i][
                self.seq_len :
            ]
    return out
|
ldf_models/tools/attention.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
try:
|
| 5 |
+
import flash_attn_interface
|
| 6 |
+
|
| 7 |
+
FLASH_ATTN_3_AVAILABLE = True
|
| 8 |
+
except ModuleNotFoundError:
|
| 9 |
+
FLASH_ATTN_3_AVAILABLE = False
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import flash_attn
|
| 13 |
+
|
| 14 |
+
FLASH_ATTN_2_AVAILABLE = True
|
| 15 |
+
except ModuleNotFoundError:
|
| 16 |
+
FLASH_ATTN_2_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
import warnings
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"flash_attention",
|
| 22 |
+
"attention",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.0,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    Variable-length flash attention (FA3 preferred, FA2 fallback).

    q: [B, Lq, Nq, C1].
    k: [B, Lk, Nk, C1].
    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens: [B]. Valid query lengths per batch item (None = all Lq).
    k_lens: [B]. Valid key lengths per batch item (None = all Lk).
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
    causal: bool. Whether to apply causal attention mask.
    window_size: (left right). If not (-1, -1), apply sliding window local attention.
    deterministic: bool. If True, slightly slower and uses more memory.
    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    version: 2 or 3 to force a flash-attn version; None picks FA3 when present.

    Returns a [B, Lq, Nq, C2] tensor in the original dtype of ``q``.
    Requires CUDA tensors with head dim <= 256.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == "cuda" and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        # Cast to a half dtype only when needed; flash-attn requires fp16/bf16.
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query: pack into the varlen layout (total_tokens, N, C)
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(
            device=q.device, non_blocking=True
        )
    else:
        # NOTE: the comprehension variable `v` shadows the value tensor only
        # inside the comprehension scope; the outer `v` is untouched.
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value: same varlen packing as the query
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(
            device=k.device, non_blocking=True
        )
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    # flash-attn requires q/k/v to share one dtype.
    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            "Flash attention 3 is not available, use flash attention 2 instead."
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        # cu_seqlens_* are the cumulative-offset arrays (prefix sums) that
        # delimit each batch item inside the packed token dimension.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic,
        )[0].unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
            .cumsum(0, dtype=torch.int32)
            .to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
        ).unflatten(0, (b, lq))

    # output: restore the caller's original dtype
    return x.type(out_dtype)
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.0,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """
    Attention entry point: dispatch to flash attention when available,
    otherwise fall back to torch's scaled_dot_product_attention.

    Shapes follow flash_attention: q [B, Lq, Nq, C], k/v [B, Lk, Nk, C].
    Returns a [B, Lq, Nq, C] tensor in the original dtype of ``q``.

    Fixes vs. the naive fallback: the SDPA path now honors
    ``softmax_scale`` and ``q_scale`` (previously silently ignored) and
    restores the input dtype on return, matching the flash path.
    ``window_size`` and ``deterministic`` remain flash-only.
    """
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
            )
        attn_mask = None
        out_dtype = q.dtype

        # SDPA expects [B, N, L, C]; compute in the requested half dtype.
        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        # Mirror flash_attention: apply the extra query scale before QK^T.
        if q_scale is not None:
            q = q * q_scale

        out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            is_causal=causal,
            dropout_p=dropout_p,
            scale=softmax_scale,  # None -> SDPA's default 1/sqrt(C)
        )

        out = out.transpose(1, 2).contiguous()
        # Restore the caller's dtype, consistent with the flash path.
        return out.type(out_dtype)
|
ldf_models/tools/t5.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from transformers.models.t5.modeling_t5
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"T5Model",
|
| 14 |
+
"T5Encoder",
|
| 15 |
+
"T5Decoder",
|
| 16 |
+
"T5EncoderModel",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def fp16_clamp(x):
|
| 21 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
| 22 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
| 23 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
| 24 |
+
return x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def init_weights(m):
    """Initialize a T5 sub-module in place, following the T5 fan-in scheme.

    Intended for use with ``model.apply(init_weights)``; modules of other
    types are left untouched. The elif order matters: it matches the most
    specific type first.
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
    elif isinstance(m, T5FeedForward):
        in_std = m.dim**-0.5
        nn.init.normal_(m.gate[0].weight, std=in_std)
        nn.init.normal_(m.fc1.weight, std=in_std)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        in_std = m.dim**-0.5
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
        nn.init.normal_(m.k.weight, std=in_std)
        nn.init.normal_(m.v.weight, std=in_std)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
        )
|
| 46 |
+
|
| 47 |
+
class GELU(nn.Module):
    """Tanh-approximated GELU activation (the "gelu_new" variant used by T5)."""

    def forward(self, x):
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))
|
| 60 |
+
|
| 61 |
+
class T5LayerNorm(nn.Module):
    """RMS normalization without mean-centering or bias (T5 style)."""

    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Statistics are computed in float32 for numerical stability.
        inv_rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * inv_rms
        # Cast back only when the learned scale lives in a half dtype.
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x
|
| 74 |
+
|
| 75 |
+
class T5Attention(nn.Module):
    """Multi-head attention with additive bias, T5-style (no QK scaling)."""

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # Projection layers (bias-free, per the T5 architecture).
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x: [B, L1, C].
        context: [B, L2, C] or None (None = self-attention).
        mask: [B, L2] or [B, L1, L2] or None; zeros mark masked positions.
        pos_bias: additive relative-position bias, broadcastable to
            [B, num_heads, L1, L2], or None.
        """
        if context is None:
            context = x
        b = x.size(0)
        h, d = self.num_heads, self.head_dim

        # Project and split into heads: [B, L, h, d].
        q = self.q(x).view(b, -1, h, d)
        k = self.k(context).view(b, -1, h, d)
        v = self.v(context).view(b, -1, h, d)

        # Assemble the additive attention bias.
        bias = x.new_zeros(b, h, q.size(1), k.size(1))
        if pos_bias is not None:
            bias = bias + pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            expanded = (
                mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
            )
            bias = bias.masked_fill(expanded == 0, torch.finfo(x.dtype).min)

        # T5 applies no 1/sqrt(d) scaling; softmax runs in float32.
        scores = torch.einsum("binc,bjnc->bnij", q, k) + bias
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)
        out = torch.einsum("bnij,bjnc->binc", weights, v)

        # Merge heads and project back to the model dimension.
        out = self.o(out.reshape(b, -1, h * d))
        return self.dropout(out)
|
| 126 |
+
|
| 127 |
+
class T5FeedForward(nn.Module):
    """T5 gated feed-forward block (GEGLU variant: fc1(x) * gelu(gate(x)))."""

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # Attribute order is kept to preserve parameter/RNG ordering.
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Gated expansion, dropout, then projection back to model dim.
        hidden = self.dropout(self.fc1(x) * self.gate(x))
        return self.dropout(self.fc2(hidden))
|
| 146 |
+
|
| 147 |
+
class T5SelfAttention(nn.Module):
    """T5 encoder block: pre-norm self-attention plus gated feed-forward,
    each wrapped in a residual connection clamped for fp16 stability."""

    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        # with shared_pos the parent computes one bias and passes it via
        # pos_bias; otherwise each block owns its own relative embedding
        self.pos_embedding = (
            None
            if shared_pos
            else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
        )

    def forward(self, x, mask=None, pos_bias=None):
        """x: [B, L, C]; mask: [B, L] or [B, L, L] or None."""
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class T5CrossAttention(nn.Module):
    """T5 decoder block: pre-norm self-attention, cross-attention over the
    encoder states, then a gated feed-forward — all residual and fp16-clamped."""

    def __init__(
        self,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5CrossAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        # decoder-side relative bias is unidirectional (causal direction only)
        self.pos_embedding = (
            None
            if shared_pos
            else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
        )

    def forward(
        self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None
    ):
        """x: [B, L, C]; encoder_states: [B, L2, C]; masks as in T5Attention.
        Note the position bias is applied to self-attention only."""
        e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(
            x
            + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
        )
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class T5RelativeEmbedding(nn.Module):
    """T5-style relative position bias: buckets signed key-minus-query
    distances and embeds each bucket into one scalar per attention head."""

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers: one learned bias value per (bucket, head)
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        """Return an additive attention bias of shape [1, N, Lq, Lk]."""
        device = self.embedding.weight.device
        # signed distance key_pos - query_pos, shape [Lq, Lk]
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
            lq, device=device
        ).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        """Map signed relative positions to bucket indices in [0, num_buckets)."""
        # preprocess: bidirectional splits buckets between past and future;
        # unidirectional keeps only non-positive (past) offsets
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions: exact buckets up to
        # max_exact, then logarithmically spaced buckets up to max_dist
        max_exact = num_buckets // 2
        rel_pos_large = (
            max_exact
            + (
                torch.log(rel_pos.float() / max_exact)
                / math.log(self.max_dist / max_exact)
                * (num_buckets - max_exact)
            ).long()
        )
        # clamp everything beyond max_dist into the last bucket
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
        )
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class T5Encoder(nn.Module):
    """Stack of T5 self-attention blocks over token embeddings.

    `vocab` may be an existing nn.Embedding (so the embedding can be shared
    with a decoder) or an int vocabulary size to build a fresh one.
    """

    def __init__(
        self,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Encoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = (
            vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        )
        # a single shared relative-position embedding, or None when each
        # block owns its own (see T5SelfAttention)
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
            if shared_pos
            else None
        )
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [
                T5SelfAttention(
                    dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        """ids: [B, L] token ids; mask: [B, L] or [B, L, L] or None.
        Returns hidden states of shape [B, L, dim]."""
        x = self.token_embedding(ids)
        x = self.dropout(x)
        # shared position bias is computed once and reused by every block
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class T5Decoder(nn.Module):
    """Stack of T5 cross-attention blocks with a causal self-attention mask.

    `vocab` may be an existing nn.Embedding (shared with the encoder) or an
    int vocabulary size to build a fresh one.
    """

    def __init__(
        self,
        vocab,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        num_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Decoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = (
            vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
        )
        # decoder-side shared bias is unidirectional (causal direction)
        self.pos_embedding = (
            T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
            if shared_pos
            else None
        )
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [
                T5CrossAttention(
                    dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        """ids: [B, L] decoder token ids; mask: None / [B, L] / [B, L, L].
        A lower-triangular (causal) mask is derived when mask is None or 2D."""
        b, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
        elif mask.ndim == 2:
            # combine the per-token padding mask with causal masking
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class T5Model(nn.Module):
    """Full T5 encoder-decoder with a tied input embedding (shared between
    encoder and decoder) and an untied linear LM head."""

    def __init__(
        self,
        vocab_size,
        dim,
        dim_attn,
        dim_ffn,
        num_heads,
        encoder_layers,
        decoder_layers,
        num_buckets,
        shared_pos=True,
        dropout=0.1,
    ):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers: the embedding instance is passed to both stacks, tying them
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            encoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.decoder = T5Decoder(
            self.token_embedding,
            dim,
            dim_attn,
            dim_ffn,
            num_heads,
            decoder_layers,
            num_buckets,
            shared_pos,
            dropout,
        )
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        """Encode, decode with cross-attention, then project to vocab logits.
        Returns [B, L_dec, vocab_size]."""
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def _t5(
    name,
    encoder_only=False,
    decoder_only=False,
    return_tokenizer=False,
    tokenizer_kwargs=None,
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
    """Construct a T5 variant: full model, encoder-only, or decoder-only.

    name: model id suffix used to locate the Huggingface tokenizer
        (``google/{name}``) when return_tokenizer is True.
    encoder_only / decoder_only: build just one half of the model;
        mutually exclusive.
    return_tokenizer: also build and return a HuggingfaceTokenizer.
    tokenizer_kwargs: extra keyword arguments for the tokenizer. Defaults
        to an empty dict; a None sentinel is used instead of a ``{}``
        default so the dict is not shared across calls (mutable-default
        pitfall in the original signature).
    kwargs: forwarded to the model class; must include vocab_size,
        encoder_layers and decoder_layers among the usual hyperparameters.

    Returns the model, or (model, tokenizer) when return_tokenizer is True.
    """
    # sanity check
    assert not (encoder_only and decoder_only)
    if tokenizer_kwargs is None:
        tokenizer_kwargs = {}

    # params: remap the two layer counts onto the single-stack num_layers API
    if encoder_only:
        model_cls = T5Encoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("encoder_layers")
        _ = kwargs.pop("decoder_layers")
    elif decoder_only:
        model_cls = T5Decoder
        kwargs["vocab"] = kwargs.pop("vocab_size")
        kwargs["num_layers"] = kwargs.pop("decoder_layers")
        _ = kwargs.pop("encoder_layers")
    else:
        model_cls = T5Model

    # init model (parameters are created directly on `device`)
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device / dtype
    model = model.to(dtype=dtype, device=device)

    # init tokenizer
    if return_tokenizer:
        from .tokenizers import HuggingfaceTokenizer

        tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def umt5_xxl(**kwargs):
    """Build UMT5-XXL with its published hyperparameters.

    Any keyword argument overrides a default or is forwarded to ``_t5``
    (e.g. encoder_only=True, dtype=..., device=...).
    """
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.1,
    )
    cfg.update(**kwargs)
    return _t5("umt5-xxl", **cfg)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
class T5EncoderModel:
    """Frozen UMT5-XXL text encoder plus tokenizer.

    Loads weights from `checkpoint_path`, freezes the model, and exposes
    `__call__(texts, device)` returning one embedding tensor per input text
    trimmed to its true (unpadded) token length.
    """

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=None,
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        """
        text_len: fixed tokenizer padding/truncation length.
        device: target device; defaults to the current CUDA device.
            The original signature evaluated torch.cuda.current_device()
            as the default value, which runs at class-definition (import)
            time and crashes on CUDA-less machines; it is resolved lazily
            here with identical behavior for callers.
        checkpoint_path: state-dict file loaded into the encoder.
        tokenizer_path: Huggingface tokenizer id or local path.
        shard_fn: optional callable that shards the model (e.g. FSDP);
            when given, the model is NOT moved to `device` here.
        """
        if device is None:
            device = torch.cuda.current_device()
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model: frozen, eval mode
        model = (
            umt5_xxl(
                encoder_only=True, return_tokenizer=False, dtype=dtype, device=device
            )
            .eval()
            .requires_grad_(False)
        )
        logging.info(f"loading {checkpoint_path}")
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean="whitespace"
        )

    def __call__(self, texts, device):
        """Encode `texts` and return a list of [len_i, C] tensors, one per
        input, each trimmed to the sample's number of non-padding tokens."""
        ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]
|
ldf_models/tools/tokenizers.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import html
|
| 3 |
+
import string
|
| 4 |
+
|
| 5 |
+
import ftfy
|
| 6 |
+
import regex as re
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
|
| 9 |
+
__all__ = ["HuggingfaceTokenizer"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def basic_clean(text):
    """Fix mojibake with ftfy, unescape (doubly-encoded) HTML entities, and
    strip surrounding whitespace."""
    text = ftfy.fix_text(text)
    # unescape twice to handle strings that were HTML-escaped two times
    text = html.unescape(html.unescape(text))
    return text.strip()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def whitespace_clean(text):
    """Collapse each run of whitespace in *text* to one space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def canonicalize(text, keep_punctuation_exact_string=None):
    """Lowercase *text*, drop punctuation, and collapse whitespace.

    Underscores become spaces first. When `keep_punctuation_exact_string`
    is given, occurrences of that exact substring survive while all other
    punctuation is removed.
    """
    strip_punct = str.maketrans("", "", string.punctuation)
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        # remove punctuation inside each piece but keep the separator itself
        pieces = [
            piece.translate(strip_punct)
            for piece in text.split(keep_punctuation_exact_string)
        ]
        text = keep_punctuation_exact_string.join(pieces)
    else:
        text = text.translate(strip_punct)
    return re.sub(r"\s+", " ", text.lower()).strip()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class HuggingfaceTokenizer:
    """Thin wrapper around transformers.AutoTokenizer with optional text
    cleaning and fixed-length padding/truncation."""

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        """
        name: Huggingface model id or local tokenizer path.
        seq_len: when given, every sequence is padded/truncated to this length.
        clean: one of None, 'whitespace', 'lower', 'canonicalize'.
        kwargs: forwarded to AutoTokenizer.from_pretrained.
        """
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenize a string or list of strings.

        Returns input_ids, or (input_ids, attention_mask) when
        return_mask=True is passed; remaining kwargs go to the tokenizer.
        """
        return_mask = kwargs.pop("return_mask", False)

        # arguments: fixed-length padding only when seq_len was configured
        _kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            _kwargs.update(
                {
                    "padding": "max_length",
                    "truncation": True,
                    "max_length": self.seq_len,
                }
            )
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        # apply the cleaning mode chosen at construction time
        if self.clean == "whitespace":
            text = whitespace_clean(basic_clean(text))
        elif self.clean == "lower":
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == "canonicalize":
            text = canonicalize(basic_clean(text))
        return text
|
ldf_models/tools/wan_model.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module uses modified code from Alibaba Wan Team
|
| 2 |
+
# Original source: https://github.com/Wan-Video/Wan2.2
|
| 3 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 4 |
+
# Modified to support stream mode for cross-attention.
|
| 5 |
+
# Added causal attention for self-attention (1d case)
|
| 6 |
+
# Added context length correction.
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 13 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 14 |
+
|
| 15 |
+
from .attention import flash_attention
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def sinusoidal_embedding_1d(dim, position):
    """Return [cos | sin] sinusoidal embeddings for a 1D position tensor.

    dim must be even; the output has shape [len(position), dim] in float64,
    with the first dim/2 columns holding cosines and the rest sines.
    """
    assert dim % 2 == 0
    half = dim // 2
    pos = position.type(torch.float64)

    # frequencies 10000^(-k/half) for k = 0..half-1
    inv_freq = torch.pow(10000, -torch.arange(half).to(pos).div(half))
    angles = torch.outer(pos, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@torch.amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex rotary-embedding multipliers.

    Returns a complex tensor of shape [max_seq_len, dim // 2] whose entry
    (p, k) is exp(i * p / theta^(2k/dim)); computed in float64 with
    autocast explicitly disabled for precision.
    """
    assert dim % 2 == 0
    positions = torch.arange(max_seq_len)
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@torch.amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """Apply 3D rotary position embedding to attention inputs.

    x: [B, L, num_heads, head_dim]; grid_sizes: [B, 3] with (F, H, W) per
    sample; freqs: complex rotary table of shape [max_len, head_dim // 2].
    Positions past each sample's F*H*W are passed through unrotated.
    Returns a float tensor of x's shape.
    """
    n, c = x.size(2), x.size(3) // 2

    # split freqs: the head_dim halves are allotted to the frame, height and
    # width axes (frame gets the remainder)
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers: view channel pairs as complex numbers
        x_i = torch.view_as_complex(
            x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
        )
        # per-axis frequencies broadcast over the (F, H, W) grid, flattened
        # to match the token order
        freqs_i = torch.cat(
            [
                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)

        # apply rotary embedding (complex multiplication == 2D rotation),
        # then concatenate the untouched padding tail
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        x_i = torch.cat([x_i, x[i, seq_len:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class WanRMSNorm(nn.Module):
    """Root-mean-square LayerNorm with a learnable per-channel gain.

    Normalization is performed in float32, then the result is cast back to
    the input dtype before applying the gain.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # per-channel scale, initialized to identity
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """x: [B, L, C] -> same shape and dtype."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight

    def _norm(self, x):
        # divide by the RMS taken over the channel dimension
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class WanLayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in float32 and casts the result back
    to the input dtype; affine parameters are off by default."""

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        """x: [B, L, C] -> same shape and dtype."""
        normalized = super().forward(x.float())
        return normalized.type_as(x)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class WanSelfAttention(nn.Module):
    """Multi-head self-attention with rotary position embedding and optional
    RMS-normalized queries/keys; causal=True restricts attention to earlier
    positions (the 1D case)."""

    def __init__(
        self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, causal=False
    ):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps
        self.causal = causal
        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        # optional RMS norm applied to projected queries/keys
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        # rotate q/k with RoPE before attention; v is left unrotated
        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size,
            causal=self.causal,
        )

        # output: merge heads and project
        x = x.flatten(2)
        x = self.o(x)
        return x
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class WanCrossAttention(WanSelfAttention):
    """Cross-attention over a text/context sequence; reuses WanSelfAttention's
    projections. In stream mode the context carries one entry per query token,
    so the query tensor is reshaped to the context's batch dimension."""

    def forward(self, x, context, context_lens):
        r"""
        Args non-stream mode:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        Args stream mode:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [BxL1, L2, C]
            context_lens(Tensor): Shape [BxL1]
        """
        # remember the query shape so stream mode can restore it after the
        # batch dimension is matched to the context's
        out_sizes = x.size()
        b, n, d = context.size(0), self.num_heads, self.head_dim

        # compute query, key, value (query is viewed with the context's batch)
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        # compute attention
        x = flash_attention(q, k, v, k_lens=context_lens)

        # output: merge heads, restore the original query shape, project
        x = x.flatten(2).view(*out_sizes)
        x = self.o(x)
        return x
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class WanAttentionBlock(nn.Module):
    """DiT-style transformer block: AdaLN-modulated self-attention,
    cross-attention over text context, and an MLP.

    The 6 modulation signals (shift/scale/gate for attention, and again for
    the FFN) come from `e` added to a learned per-block `modulation` table.
    """

    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        causal=False,
    ):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.causal = causal
        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(
            dim, num_heads, window_size, qk_norm, eps, causal
        )
        # norm3 feeds the cross-attention; affine only when cross_attn_norm
        self.norm3 = (
            WanLayerNorm(dim, eps, elementwise_affine=True)
            if cross_attn_norm
            else nn.Identity()
        )

        self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )

        # modulation: learned base for the 6 AdaLN signals
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, L1, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        assert e.dtype == torch.float32
        # modulation arithmetic is kept in float32 for stability
        with torch.amp.autocast("cuda", dtype=torch.float32):
            e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
        assert e[0].dtype == torch.float32

        # self-attention: shift (e[0]) and scale (e[1]) the normed input,
        # gate the residual with e[2]
        y = self.self_attn(
            self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
            seq_lens,
            grid_sizes,
            freqs,
        )
        with torch.amp.autocast("cuda", dtype=torch.float32):
            x = x + y * e[2].squeeze(2)

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            # cross-attention residual (not AdaLN-modulated)
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            # FFN with shift (e[3]), scale (e[4]) and residual gate (e[5])
            y = self.ffn(
                self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)
            )
            with torch.amp.autocast("cuda", dtype=torch.float32):
                x = x + y * e[5].squeeze(2)
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class Head(nn.Module):
    """Output head: AdaLN-modulated LayerNorm followed by a linear projection
    from the hidden dimension back to per-patch output channels."""

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers — each token predicts a full patch worth of channels
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

        # Learned base shift/scale table, refined per token by `e` at runtime.
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C] — token features
            e(Tensor): Shape [B, L1, C] — per-token conditioning embedding
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast("cuda", dtype=torch.float32):
            # Broadcast-add the table, then split into (shift, scale).
            shift, scale = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).unbind(2)
            x = self.head(self.norm(x) * (1 + scale) + shift)
        return x
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.
    """

    # Architectural constants excluded from the serialized diffusers config.
    ignore_for_config = [
        "patch_size",
        "cross_attn_norm",
        "qk_norm",
        "text_dim",
        "window_size",
    ]
    # Keep each attention block intact when sharding across devices.
    _no_split_modules = ["WanAttentionBlock"]

    @register_to_config
    def __init__(
        self,
        model_type="t2v",
        patch_size=(1, 2, 2),
        text_len=512,
        in_dim=16,
        dim=2048,
        ffn_dim=8192,
        freq_dim=256,
        text_dim=4096,
        out_dim=16,
        num_heads=16,
        num_layers=32,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=True,
        eps=1e-6,
        causal=False,
    ):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 8192):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 16):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 32):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to False):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
            causal (`bool`, *optional*, defaults to False):
                Passed through to each WanAttentionBlock (temporally causal
                self-attention)
        """

        super().__init__()

        assert model_type in ["t2v", "i2v", "ti2v", "s2v"]
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.causal = causal
        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size
        )
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
        )

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
        )
        # Projects the time embedding to the 6 AdaLN modulation vectors.
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList(
            [
                WanAttentionBlock(
                    dim,
                    ffn_dim,
                    num_heads,
                    window_size,
                    qk_norm,
                    cross_attn_norm,
                    eps,
                    causal,
                )
                for _ in range(num_layers)
            ]
        )

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        # RoPE tables: the head dim is split between temporal and two spatial axes.
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        )

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B], or [B, seq_len] for
                per-token noise levels (diffusion forcing)
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == "i2v":
            assert y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            # Conditioning frames are concatenated on the channel axis.
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
        )
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        # Right-pad each sequence to seq_len and stack into one batch tensor.
        x = torch.cat(
            [
                torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
                for u in x
            ]
        )

        # time embeddings
        # Each token carries its own timestep (diffusion-forcing style).
        if t.dim() == 1:  # bs
            # NOTE(review): `expand` on a 1-D [B] tensor aligns B with the
            # trailing (seq_len) axis, so this only broadcasts cleanly when
            # B == 1 (or B == seq_len); per-token expansion for B > 1 would
            # need t.unsqueeze(1).expand(-1, seq_len) — confirm with callers.
            t = t.expand(t.size(0), seq_len)
        with torch.amp.autocast("cuda", dtype=torch.float32):
            bt = t.size(0)
            t = t.flatten()
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t)
                .unflatten(0, (bt, seq_len))
                .float()
            )
            # e0: per-token 6-way AdaLN modulation, shape [B, seq_len, 6, dim].
            e0 = self.time_projection(e).unflatten(2, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = torch.tensor([u.size(0) for u in context], dtype=torch.long)
        # Zero-pad every text embedding to the fixed text_len before projecting.
        context = self.text_embedding(
            torch.stack(
                [
                    torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                    for u in context
                ]
            )
        )

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens,
        )

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            # Drop padding tokens, then re-interleave patch contents back into
            # the (C, F, H, W) layout via einsum index permutation.
            u = u[: math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum("fhwpqrc->cfphqwr", u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)

        # init output layer — zero so the model starts as an identity residual
        nn.init.zeros_(self.head.head.weight)
|
ldf_models/tools/wan_vae_1d.py
ADDED
|
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module uses modified code from Alibaba Wan Team
|
| 2 |
+
# Original source: https://github.com/Wan-Video/Wan2.2
|
| 3 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 4 |
+
# Modified to support 1d features with (B, C, T)
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
CACHE_T = 2
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CausalConv1d(nn.Conv1d):
    """Conv1d that pads only on the left (past) side, so the output at time t
    never depends on inputs after t. Supports streaming via `cache_x`: frames
    from the previous chunk replace part of the zero padding."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Move the symmetric padding entirely to the left; disable the
        # built-in padding so we control it in forward().
        self._padding = (2 * self.padding[0], 0)
        self.padding = (0,)

    def forward(self, x, cache_x=None):
        left_pad = self._padding[0]
        if cache_x is not None and left_pad > 0:
            # Prepend cached frames from the previous chunk and shrink the
            # zero padding accordingly.
            x = torch.cat([cache_x.to(x.device), x], dim=2)
            left_pad -= cache_x.shape[2]
        x = F.pad(x, (left_pad, 0))
        return super().forward(x)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class RMS_norm(nn.Module):
    """RMS normalization with a learnable per-channel gain (and optional bias).

    Implemented as L2-normalize along the channel axis scaled by sqrt(dim),
    which equals dividing by the root-mean-square of the channel vector.
    """

    def __init__(self, dim, channel_first=True, bias=False):
        super().__init__()
        # Parameter shape broadcasts over the trailing time axis when the
        # input is channel-first (B, C, T); otherwise it is a plain (C,).
        param_shape = (dim, 1) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape)) if bias else 0.0

    def forward(self, x):
        axis = 1 if self.channel_first else -1
        normed = F.normalize(x, dim=axis) * self.scale
        return normed * self.gamma + self.bias
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Upsample(nn.Upsample):
    """nn.Upsample wrapper that upcasts to float32 before interpolating.

    Works around missing bfloat16 support in nearest-neighbor interpolation;
    the result is cast back to the input dtype.
    """

    def forward(self, x):
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Resample(nn.Module):
    # Temporal resampling with streaming support.
    #  - "upsample1d": conv to 2*dim channels, then interleave the two halves
    #    along time, doubling T (except for the very first streamed chunk).
    #  - "downsample1d": stride-2 causal conv, halving T.
    # feat_cache holds per-layer state between chunks; feat_idx is a mutable
    # one-element cursor into it, shared across the whole encoder/decoder.
    def __init__(self, dim, mode):
        assert mode in (
            "upsample1d",
            "downsample1d",
        )
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample1d":
            self.time_conv = CausalConv1d(dim, dim * 2, (3,), padding=(1,))
        elif mode == "downsample1d":
            self.time_conv = CausalConv1d(dim, dim, (3,), stride=(2,), padding=(0,))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t = x.size()
        if self.mode == "upsample1d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: nothing cached yet. Store the "Rep"
                    # sentinel and pass x through unchanged this time.
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:].clone()
                    # NOTE(review): `feat_cache[idx] != "Rep"` compares a
                    # tensor against a string; it relies on Tensor.__ne__
                    # falling back to identity for non-convertible operands —
                    # confirm this holds on the targeted torch versions.
                    if (
                        cache_x.shape[2] < 2
                        and feat_cache[idx] is not None
                        and feat_cache[idx] != "Rep"
                    ):
                        # cache last frame of last two chunk
                        cache_x = torch.cat(
                            [
                                feat_cache[idx][:, :, -1]
                                .unsqueeze(2)
                                .to(cache_x.device),
                                cache_x,
                            ],
                            dim=2,
                        )
                    if (
                        cache_x.shape[2] < 2
                        and feat_cache[idx] is not None
                        and feat_cache[idx] == "Rep"
                    ):
                        # Second chunk after the sentinel: left-pad the cache
                        # with zeros instead of real history.
                        cache_x = torch.cat(
                            [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
                            dim=2,
                        )
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
                    # Split the doubled channels and interleave them along
                    # time: (b, 2c, t) -> (b, c, 2t).
                    x = x.reshape(b, 2, c, t)
                    x = torch.stack((x[:, 0, :, :], x[:, 1, :, :]), 3)
                    x = x.reshape(b, c, t * 2)

        if self.mode == "downsample1d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: remember it; the stride-2 conv runs without
                    # extra left context this time.
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:].clone()
                    # Prepend the last cached frame so chunk boundaries see
                    # one frame of history.
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class ResidualBlock(nn.Module):
    # Pre-activation residual block for the 1-D causal VAE:
    # RMS_norm -> SiLU -> conv -> RMS_norm -> SiLU -> dropout -> conv,
    # added to a shortcut (1x1 conv only when channel counts differ).
    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim),
            nn.SiLU(),
            CausalConv1d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv1d(out_dim, out_dim, 3, padding=1),
        )
        self.shortcut = (
            CausalConv1d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # Streaming: every CausalConv1d in the main path consumes one cache
        # slot holding the last CACHE_T frames of its input from the previous
        # chunk; feat_idx advances once per conv.
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv1d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class AvgDown1D(nn.Module):
|
| 183 |
+
def __init__(
|
| 184 |
+
self,
|
| 185 |
+
in_channels,
|
| 186 |
+
out_channels,
|
| 187 |
+
factor_t,
|
| 188 |
+
):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.in_channels = in_channels
|
| 191 |
+
self.out_channels = out_channels
|
| 192 |
+
self.factor_t = factor_t
|
| 193 |
+
self.factor = self.factor_t
|
| 194 |
+
|
| 195 |
+
assert in_channels * self.factor % out_channels == 0
|
| 196 |
+
self.group_size = in_channels * self.factor // out_channels
|
| 197 |
+
|
| 198 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 199 |
+
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
| 200 |
+
pad = (pad_t, 0)
|
| 201 |
+
x = F.pad(x, pad)
|
| 202 |
+
B, C, T = x.shape
|
| 203 |
+
x = x.view(
|
| 204 |
+
B,
|
| 205 |
+
C,
|
| 206 |
+
T // self.factor_t,
|
| 207 |
+
self.factor_t,
|
| 208 |
+
)
|
| 209 |
+
x = x.permute(0, 1, 3, 2).contiguous()
|
| 210 |
+
x = x.view(
|
| 211 |
+
B,
|
| 212 |
+
C * self.factor,
|
| 213 |
+
T // self.factor_t,
|
| 214 |
+
)
|
| 215 |
+
x = x.view(
|
| 216 |
+
B,
|
| 217 |
+
self.out_channels,
|
| 218 |
+
self.group_size,
|
| 219 |
+
T // self.factor_t,
|
| 220 |
+
)
|
| 221 |
+
x = x.mean(dim=2)
|
| 222 |
+
return x
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class DupUp1D(nn.Module):
    """Parameter-free temporal upsampling by factor_t (inverse of AvgDown1D).

    Duplicates channels, then unfolds the channel axis back into time so T
    grows by factor_t. For the first streamed chunk the leading factor_t - 1
    frames (which correspond to padding) are dropped.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor_t,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor = self.factor_t

        # Output must be reachable by an integer number of channel repeats.
        assert out_channels * self.factor % in_channels == 0
        self.repeats = out_channels * self.factor // in_channels

    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        B = x.size(0)
        x = x.repeat_interleave(self.repeats, dim=1)
        T = x.size(2)
        # (B, out, factor_t, T) -> (B, out, T, factor_t) -> (B, out, T*factor_t)
        x = x.view(B, self.out_channels, self.factor_t, T)
        x = x.permute(0, 1, 3, 2).reshape(B, self.out_channels, T * self.factor_t)
        if first_chunk:
            x = x[:, :, self.factor_t - 1 :]
        return x
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class Down_ResidualBlock(nn.Module):
    """Encoder stage: `mult` residual blocks (plus an optional stride-2
    temporal Resample), with a parameter-free AvgDown1D shortcut when the
    stage downsamples in time."""

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False):
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = (
            AvgDown1D(in_dim, out_dim, factor_t=2) if temperal_downsample else None
        )

        # Main path: residual blocks, then the optional temporal downsample.
        stages = []
        width = in_dim
        for _ in range(mult):
            stages.append(ResidualBlock(width, out_dim, dropout))
            width = out_dim
        if temperal_downsample:
            stages.append(Resample(out_dim, mode="downsample1d"))
        self.downsamples = nn.Sequential(*stages)

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        shortcut_in = x.clone()
        for stage in self.downsamples:
            x = stage(x, feat_cache, feat_idx)
        if self.avg_shortcut is None:
            return x
        return x + self.avg_shortcut(shortcut_in)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class Up_ResidualBlock(nn.Module):
    """Decoder stage: `mult` residual blocks (plus an optional temporal
    upsample Resample), with a parameter-free DupUp1D shortcut when the
    stage upsamples in time."""

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False):
        super().__init__()
        # Shortcut path with upsample
        self.avg_shortcut = (
            DupUp1D(in_dim, out_dim, factor_t=2) if temperal_upsample else None
        )

        # Main path: residual blocks, then the optional temporal upsample.
        stages = []
        width = in_dim
        for _ in range(mult):
            stages.append(ResidualBlock(width, out_dim, dropout))
            width = out_dim
        if temperal_upsample:
            stages.append(Resample(out_dim, mode="upsample1d"))
        self.upsamples = nn.Sequential(*stages)

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        out = x.clone()
        for stage in self.upsamples:
            out = stage(out, feat_cache, feat_idx)
        if self.avg_shortcut is None:
            return out
        return out + self.avg_shortcut(x, first_chunk)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class Encoder1d(nn.Module):
    # Causal 1-D convolutional encoder: (B, input_dim, T) -> (B, z_dim, T')
    # where T' is T divided by 2 for every True in temperal_downsample.
    def __init__(
        self,
        input_dim,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0  # running temporal scale; tracked for parity, unused below

        # init block
        self.conv1 = CausalConv1d(input_dim, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_down_flag = (
                temperal_downsample[i] if i < len(temperal_downsample) else False
            )
            downsamples.append(
                Down_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks,
                    temperal_downsample=t_down_flag,
                )
            )
            scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        # NOTE: `out_dim` here is the last stage's width, leaked from the loop.
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout),
            RMS_norm(out_dim),
            CausalConv1d(out_dim, out_dim, 1),
            ResidualBlock(out_dim, out_dim, dropout),
        )

        # # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim),
            nn.SiLU(),
            CausalConv1d(out_dim, z_dim, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # Stem conv with its own streaming cache slot.
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Borrow one frame from the previous cache so the slot always
                # holds two frames of history.
                cache_x = torch.cat(
                    [
                        feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                        cache_x,
                    ],
                    dim=2,
                )
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv1d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)

        return x
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
class Decoder1d(nn.Module):
    # Causal 1-D convolutional decoder: (B, z_dim, T') -> (B, output_dim, T)
    # where T' is multiplied by 2 for every True in temperal_upsample.
    def __init__(
        self,
        output_dim,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        temperal_upsample=[False, True, True],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.temperal_upsample = temperal_upsample

        # dimensions (mirror of the encoder: widest first)
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)  # tracked for parity, unused below
        # init block
        self.conv1 = CausalConv1d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout),
            RMS_norm(dims[0]),
            CausalConv1d(dims[0], dims[0], 1),
            ResidualBlock(dims[0], dims[0], dropout),
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
            upsamples.append(
                Up_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks + 1,
                    temperal_upsample=t_up_flag,
                )
            )
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        # NOTE: `out_dim` is the last stage's width, leaked from the loop.
        self.head = nn.Sequential(
            RMS_norm(out_dim),
            nn.SiLU(),
            CausalConv1d(out_dim, output_dim, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        # Stem conv with its own streaming cache slot.
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Borrow one frame from the previous cache so the slot always
                # holds two frames of history.
                cache_x = torch.cat(
                    [
                        feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                        cache_x,
                    ],
                    dim=2,
                )
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, first_chunk)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv1d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat(
                        [
                            feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
                            cache_x,
                        ],
                        dim=2,
                    )
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def count_conv1d(model):
    """Count the CausalConv1d modules in *model*.

    Used by WanVAE_.clear_cache to size the streaming feature-cache lists
    (one slot per causal convolution).
    """
    return sum(1 for module in model.modules() if isinstance(module, CausalConv1d))
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class WanVAE_(nn.Module):
    """1-D causal VAE (WAN-style) over temporal feature sequences.

    Encodes (B, input_dim, T) into latents (B, z_dim, T_latent) and decodes
    back, processing time causally in chunks (a 1-frame head chunk followed by
    4-frame chunks on the encoder side, one latent frame at a time on the
    decoder side). Streaming is implemented with per-convolution feature
    caches that are (re)built by :meth:`clear_cache`.
    """

    def __init__(
        self,
        input_dim,
        dim=160,
        dec_dim=256,
        z_dim=16,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=1,
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.temperal_downsample = temperal_downsample
        # Decoder upsamples in the reverse order of the encoder downsampling.
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder1d(
            input_dim,
            dim,
            z_dim * 2,  # encoder emits mu and log_var concatenated on channels
            dim_mult,
            num_res_blocks,
            self.temperal_downsample,
            dropout,
        )
        self.conv1 = CausalConv1d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv1d(z_dim, z_dim, 1)
        self.decoder = Decoder1d(
            input_dim,
            dec_dim,
            z_dim,
            dim_mult,
            num_res_blocks,
            self.temperal_upsample,
            dropout,
        )

    def forward(self, x, scale=[0, 1]):
        """Deterministic round trip: encode to mu, decode mu back.

        Returns (x_recon, mu); no reparameterization noise is added here.
        """
        mu = self.encode(x, scale)
        x_recon = self.decode(mu, scale)
        return x_recon, mu

    def encode(self, x, scale, return_dist=False):
        """Encode ``x`` (B, C, T) into the latent mean (and log-variance).

        The sequence is consumed causally: one leading frame, then 4-frame
        chunks, feeding the encoder feature caches between chunks. ``scale``
        is a (shift, scale) pair applied to mu; tensor entries are broadcast
        as (1, z_dim, 1).
        """
        self.clear_cache()
        t = x.shape[2]
        # 1 head frame + ceil((t - 1) / 4) four-frame chunks.
        iter_ = 1 + (t - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                out = torch.cat([out, out_], 2)
        # Split the doubled channel dim into distribution parameters.
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
                1, self.z_dim, 1
            )
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        if return_dist:
            return mu, log_var
        return mu

    def decode(self, z, scale):
        """Decode latents ``z`` (B, z_dim, T_latent) back to feature space.

        Inverts the ``scale`` normalization applied in :meth:`encode`, then
        decodes one latent frame at a time through the cached decoder.
        """
        self.clear_cache()
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i : i + 1],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=True,
                )
            else:
                out_ = self.decoder(
                    x[:, :, i : i + 1],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                )
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    @torch.no_grad()
    def stream_encode(self, x, first_chunk, scale, return_dist=False):
        """Streaming encode of one chunk; caches persist across calls.

        Unlike :meth:`encode`, this does NOT reset the feature caches, so a
        new stream must be preceded by an external clear_cache(). The first
        chunk is expected to carry the extra leading frame (1 + 4k frames);
        later chunks carry multiples of 4 frames.
        """
        t = x.shape[2]
        if first_chunk:
            iter_ = 1 + (t - 1) // 4
        else:
            iter_ = t // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                if first_chunk:
                    out = self.encoder(
                        x[:, :, :1],
                        feat_cache=self._enc_feat_map,
                        feat_idx=self._enc_conv_idx,
                    )
                else:
                    out = self.encoder(
                        x[:, :, :4],
                        feat_cache=self._enc_feat_map,
                        feat_idx=self._enc_conv_idx,
                    )
            else:
                if first_chunk:
                    out_ = self.encoder(
                        x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
                        feat_cache=self._enc_feat_map,
                        feat_idx=self._enc_conv_idx,
                    )
                else:
                    out_ = self.encoder(
                        x[:, :, 4 * i : 4 * (i + 1)],
                        feat_cache=self._enc_feat_map,
                        feat_idx=self._enc_conv_idx,
                    )
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
                1, self.z_dim, 1
            )
        else:
            mu = (mu - scale[0]) * scale[1]
        if return_dist:
            return mu, log_var
        else:
            return mu

    @torch.no_grad()
    def stream_decode(self, z, first_chunk, scale):
        """Streaming decode of one latent chunk; caches persist across calls.

        Like :meth:`stream_encode`, this does not reset the caches itself.
        """
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i : i + 1],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=first_chunk,  # Use the external first_chunk parameter
                )
            else:
                out_ = self.decoder(
                    x[:, :, i : i + 1],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx,
                    first_chunk=False,  # Explicitly set to False for subsequent time steps within the same chunk
                )
                out = torch.cat([out, out_], 2)
        return out

    def reparameterize(self, mu, log_var):
        """Standard VAE reparameterization: mu + sigma * eps."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        """Sample a latent for ``imgs`` (mean only when ``deterministic``).

        NOTE(review): this calls ``self.encode(imgs)`` without the required
        ``scale`` argument, and encode() only returns a (mu, log_var) tuple
        when return_dist=True — this method looks stale; confirm before use.
        """
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        """Rebuild the streaming caches: one empty slot per CausalConv1d."""
        self._conv_num = count_conv1d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv1d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
|
ldf_models/vae_wan_1d.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from .tools.wan_vae_1d import WanVAE_
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class VAEWanModel(nn.Module):
    """Training/inference wrapper around WanVAE_ for motion feature sequences.

    Handles per-channel (mean, std) normalization, (B, T, C) <-> (B, C, T)
    layout conversion, masked reconstruction/velocity/KL losses in
    :meth:`forward`, and convenience encode/decode plus streaming variants.
    """

    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=[1, 1, 1],
        temperal_downsample=[True, True],
        vel_window=[0, 0],
        **kwargs,
    ):
        super().__init__()

        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        self.dim_mult = dim_mult
        self.temperal_downsample = temperal_downsample
        # [start, end) channel slice used for the extra velocity loss term.
        self.vel_window = vel_window
        self.RECONS_LOSS = nn.SmoothL1Loss()
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
        # NOTE(review): 10e-6 equals 1e-5 — confirm whether 1e-6 was intended.
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)

        # Per-channel normalization stats; identity when no paths are given.
        if self.mean_path is not None:
            self.register_buffer(
                "mean", torch.from_numpy(np.load(self.mean_path)).float()
            )
        else:
            self.register_buffer("mean", torch.zeros(input_dim))

        if self.std_path is not None:
            self.register_buffer(
                "std", torch.from_numpy(np.load(self.std_path)).float()
            )
        else:
            self.register_buffer("std", torch.ones(input_dim))

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            dropout=self.dropout,
        )

        # Total temporal downsampling factor: 2 per True flag.
        downsample_factor = 1
        for flag in self.temperal_downsample:
            if flag:
                downsample_factor *= 2
        self.downsample_factor = downsample_factor

    def preprocess(self, x):
        """Move time and channel axes for the conv model."""
        # (bs, T, C) -> (bs, C, T)
        x = x.permute(0, 2, 1)
        return x

    def postprocess(self, x):
        """Inverse of :meth:`preprocess`."""
        # (bs, C, T) -> (bs, T, C)
        x = x.permute(0, 2, 1)
        return x

    def forward(self, x):
        """Compute masked VAE training losses.

        Args:
            x: dict with "feature" (B, T, C) and per-sample valid lengths
               "feature_length" (B,).

        Returns:
            dict with "total", "recons", "velocity", "kl" loss tensors.
        """
        features = x["feature"]
        feature_length = x["feature_length"]
        features = (features - self.mean) / self.std
        # create mask based on feature_length
        batch_size, seq_len = features.shape[:2]
        mask = torch.zeros(
            batch_size, seq_len, dtype=torch.bool, device=features.device
        )
        for i in range(batch_size):
            mask[i, : feature_length[i]] = True

        x_in = self.preprocess(features)  # (bs, input_dim, T)
        mu, log_var = self.model.encode(
            x_in, scale=[0, 1], return_dist=True
        )  # (bs, z_dim, T)
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])  # (bs, input_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)

        # Round-trip length can differ from the input; align both to the min.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len, :]
            features = features[:, :min_len, :]
            mask = mask[:, :min_len]

        # Zero out padded timesteps before the reconstruction losses.
        mask_expanded = mask.unsqueeze(-1)
        x_out_masked = x_out * mask_expanded
        features_masked = features * mask_expanded
        loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)
        vel_start = self.vel_window[0]
        vel_end = self.vel_window[1]
        # Extra weight on the velocity channels (empty slice when window is [0, 0]).
        loss_vel = self.RECONS_LOSS(
            x_out_masked[..., vel_start:vel_end],
            features_masked[..., vel_start:vel_end],
        )

        # Compute KL divergence loss
        # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        # log_var = log(sigma^2), so we can use it directly

        # Build mask for latent space
        T_latent = mu.size(2)
        mask_downsampled = torch.zeros(
            batch_size, T_latent, dtype=torch.bool, device=features.device
        )
        for i in range(batch_size):
            # ceil(feature_length / downsample_factor) valid latent steps.
            latent_length = (
                feature_length[i] + self.downsample_factor - 1
            ) // self.downsample_factor
            mask_downsampled[i, :latent_length] = True
        mask_latent = mask_downsampled.unsqueeze(1)  # (B, 1, T_latent)

        # Compute KL loss per element
        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        # Apply mask: only compute KL loss for valid timesteps
        kl_masked = kl_per_element * mask_latent
        # Sum over all dimensions and normalize by the number of valid elements
        kl_loss = torch.sum(kl_masked) / (
            torch.sum(mask_downsampled) * mu.size(1)
        )  # normalize by valid timesteps * latent_dim

        # Total loss
        total_loss = (
            self.LAMBDA_FEATURE * loss_recons
            + self.LAMBDA_VELOCITY * loss_vel
            + self.LAMBDA_KL * kl_loss
        )

        loss_dict = {}
        loss_dict["total"] = total_loss
        loss_dict["recons"] = loss_recons
        loss_dict["velocity"] = loss_vel
        loss_dict["kl"] = kl_loss

        return loss_dict

    def encode(self, x):
        """Normalize and encode (B, T, C) features to latent means (B, T_latent, z_dim)."""
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, T, input_dim) -> (bs, input_dim, T)
        mu = self.model.encode(x_in, scale=[0, 1])  # (bs, z_dim, T)
        mu = self.postprocess(mu)  # (bs, T, z_dim)
        return mu

    def decode(self, mu):
        """Decode latents (B, T_latent, z_dim) and de-normalize to (B, T, C)."""
        mu_in = self.preprocess(mu)  # (bs, T, z_dim) -> (bs, z_dim, T)
        x_decoder = self.model.decode(mu_in, scale=[0, 1])  # (bs, z_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        x_out = x_out * self.std + self.mean
        return x_out

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
        """Chunked encode; the wrapped model's caches persist across calls,
        so clear_cache() must be called before starting a new stream."""
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, input_dim, T)
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        mu = self.postprocess(mu)  # (bs, T, z_dim)
        return mu

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
        """Chunked decode counterpart to :meth:`stream_encode`."""
        mu_in = self.preprocess(mu)  # (bs, z_dim, T)
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        x_out = x_out * self.std + self.mean
        return x_out

    def clear_cache(self):
        """Reset the wrapped model's streaming caches."""
        self.model.clear_cache()

    def generate(self, x):
        """Round-trip reconstruct a batch, trimming padding per sample.

        Returns {"generated": [tensor of shape (valid_len_i, C), ...]}.
        """
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        y_hat_out = []

        for i in range(y_hat.shape[0]):
            # cut off the padding and align lengths
            valid_len = (
                feature_length[i] - 1
            ) // self.downsample_factor * self.downsample_factor + 1
            # Make sure both have the same length (take minimum)
            y_hat_out.append(y_hat[i, :valid_len, :])

        out = {}
        out["generated"] = y_hat_out
        return out
|
ldf_utils/__init__.py
ADDED
|
File without changes
|
ldf_utils/initialize.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import time
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from importlib import import_module
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from lightning.pytorch.utilities import rank_zero_info
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Config:
    """Thin wrapper around an OmegaConf tree.

    Loads an optional YAML file, applies dotted-key command-line overrides,
    and exposes both attribute-style and dict-style access to the result.
    """

    def __init__(self, config_path: str = None, override_args: Dict[str, Any] = None):
        """Build the config, optionally from *config_path* plus overrides."""
        self.config = OmegaConf.create({})

        # Load main config if provided
        if config_path:
            self.load_yaml(config_path)
        if override_args:
            self.override_config(override_args)

    def load_yaml(self, config_path: str):
        """Load YAML configuration file"""
        loaded_config = OmegaConf.load(config_path)
        self.config = OmegaConf.merge(self.config, loaded_config)

    def override_config(self, override_args: Dict[str, Any]):
        """Handle command line override arguments.

        Keys may be dotted ("a.b.c"); OmegaConf.update creates intermediate
        nodes as needed. Values are type-converted once here instead of being
        re-parsed through a dotlist string, so path-like strings (e.g. ending
        in ".yaml") survive untouched.
        """
        for key, value in override_args.items():
            OmegaConf.update(self.config, key, self._convert_value(value))

    def _convert_value(self, value: Any) -> Any:
        """Convert a string value to bool/None/int/float when possible.

        Non-string values (overrides that are already typed) are returned
        unchanged; previously they crashed on ``value.lower()``.
        """
        if not isinstance(value, str):
            return value
        if value.lower() == "true":
            return True
        elif value.lower() == "false":
            return False
        elif value.lower() == "null":
            return None
        try:
            return int(value)
        except ValueError:
            try:
                return float(value)
            except ValueError:
                return value

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value"""
        return OmegaConf.select(self.config, key, default=default)

    def __getattr__(self, name: str) -> Any:
        """Support dot notation access"""
        # Called only for names not found normally (i.e. config keys).
        return self.config[name]

    def __getitem__(self, key: str) -> Any:
        """Support dictionary-like access"""
        return self.config[key]

    def export_config(self, path: str):
        """Export current configuration to file"""
        OmegaConf.save(self.config, path)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def parse_args():
    """Build and run the CLI parser: a required --config path plus optional
    key=value --override entries."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config", type=str, required=True, help="Path to config file"
    )
    cli.add_argument(
        "--override", type=str, nargs="+", help="Override config values (key=value)"
    )
    return cli.parse_args()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def load_config(
    config_path: Optional[str] = None, override_args: Optional[Dict[str, Any]] = None
) -> Config:
    """Load a Config, falling back to CLI arguments when no path is given.

    When *config_path* is None, ``--config`` / ``--override`` are read from
    the command line and each override entry is split on the first '='.
    """
    if config_path is None:
        args = parse_args()
        config_path = args.config
        if args.override:
            override_args = {}
            for entry in args.override:
                key, value = entry.split("=", 1)
                override_args[key.strip()] = value.strip()

    return Config(config_path, override_args)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def instantiate(target, cfg=None, hfstyle=False, **init_args):
    """Import the dotted class path *target* and construct an instance.

    With ``cfg`` set, the config is passed as the first positional argument;
    with ``hfstyle`` it is first wrapped in the class's ``config_class``.
    Remaining keyword arguments go straight to the constructor.
    """
    mod_path, cls_name = target.rsplit(".", 1)
    cls = getattr(import_module(mod_path), cls_name)
    if cfg is None:
        return cls(**init_args)
    if hfstyle:
        # Wrap the raw config object in the class's HF-style config class.
        cfg = cls.config_class(config_obj=cfg)
    return cls(cfg, **init_args)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_function(target):
    """Resolve a dotted path like "pkg.mod.fn" to the named attribute."""
    mod_path, fn_name = target.rsplit(".", 1)
    return getattr(import_module(mod_path), fn_name)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def save_config_and_codes(config, save_dir):
    """Snapshot the resolved config plus all project .py files for reproducibility.

    Writes ``<save_dir>/sanity_check/<exp_name>.yaml`` and copies every .py
    file under the current working directory (excluding ./outputs) into the
    same sanity_check tree, preserving relative paths.
    """
    os.makedirs(save_dir, exist_ok=True)
    sanity_check_dir = os.path.join(save_dir, "sanity_check")
    os.makedirs(sanity_check_dir, exist_ok=True)
    with open(os.path.join(sanity_check_dir, f"{config.exp_name}.yaml"), "w") as f:
        OmegaConf.save(config.config, f)

    root = Path.cwd()
    outputs_dir = root / "outputs"
    for src in root.rglob("*.py"):
        # Skip anything under the run-outputs tree (would snowball snapshots).
        if outputs_dir in src.parents:
            continue
        dst = Path(sanity_check_dir) / src.relative_to(root)
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(src, dst)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def print_model_size(model):
    """Log total / trainable / frozen parameter counts (rank-zero only)."""
    total = 0
    trainable = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    rank_zero_info(f"Total parameters: {total:,}")
    rank_zero_info(f"Trainable parameters: {trainable:,}")
    rank_zero_info(f"Non-trainable parameters: {(total - trainable):,}")
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def compare_statedict_and_parameters(state_dict, named_parameters, named_buffers):
    """Compare differences between state_dict and parameters"""
    sd_keys = set(state_dict)
    param_keys = {name for name, _ in named_parameters}

    only_in_state_dict = sd_keys - param_keys
    only_in_named_params = param_keys - sd_keys

    if only_in_state_dict:
        print(f"Only in state_dict (not in parameters): {sorted(only_in_state_dict)}")

    if only_in_named_params:
        print(
            f"Only in named_parameters (not in state_dict): {sorted(only_in_named_params)}"
        )

    if not only_in_state_dict and not only_in_named_params:
        print("All parameters match between state_dict and named_parameters")

    # Buffers (non-parameter state such as BatchNorm running stats) account
    # for most of the remaining state_dict entries; anything left over is
    # neither a parameter nor a buffer and is worth flagging.
    buffer_keys = {name for name, _ in named_buffers}
    buffers_only = sd_keys - param_keys - buffer_keys

    if buffers_only:
        print(
            f"Other items in state_dict (neither params nor buffers): {sorted(buffers_only)}"
        )

    print(f"Total state_dict items: {len(sd_keys)}")
    print(f"Total named_parameters: {len(param_keys)}")
    print(f"Total named_buffers: {len(buffer_keys)}")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _resolve_global_rank() -> int:
    """Resolve the global rank from environment variables."""
    # Most-specific launchers first; fall through on missing or non-integer
    # values, defaulting to rank 0 (single-process case).
    for env_name in ("GLOBAL_RANK", "RANK", "SLURM_PROCID", "LOCAL_RANK"):
        raw = os.environ.get(env_name)
        if raw is None:
            continue
        try:
            return int(raw)
        except ValueError:
            continue
    return 0
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_shared_run_time(base_dir: str, env_key: str = "PL_RUN_TIME") -> str:
    """
    Get a synchronized run time across all processes.

    This function ensures all processes (both in distributed training and multi-process
    scenarios) use the same timestamp for output directories and experiment tracking.

    Args:
        base_dir: Base directory for output files
        env_key: Environment variable key to cache the run time

    Returns:
        Synchronized timestamp string in format YYYYMMDD_HHMMSS
    """
    # Reuse a value already resolved by this process (env vars also propagate
    # to forked/spawned children).
    cached = os.environ.get(env_key)
    if cached:
        return cached

    timestamp_format = "%Y%m%d_%H%M%S"

    # Fast path: an initialized process group broadcasts rank 0's timestamp.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            run_time = datetime.now().strftime(timestamp_format)
        else:
            run_time = None
        container = [run_time]
        torch.distributed.broadcast_object_list(container, src=0)
        run_time = container[0]
        if run_time is None:
            raise RuntimeError("Failed to synchronize run time across ranks.")
        os.environ[env_key] = run_time
        return run_time

    # Slow path: no process group yet — synchronize through a file keyed by
    # the job id, written by rank 0 and polled by the other ranks.
    os.makedirs(base_dir, exist_ok=True)
    sync_token = (
        os.environ.get("SLURM_JOB_ID")
        or os.environ.get("TORCHELASTIC_RUN_ID")
        or os.environ.get("JOB_ID")
        or "default"
    )
    sync_dir = os.path.join(base_dir, ".run_time_sync")
    os.makedirs(sync_dir, exist_ok=True)
    sync_file = os.path.join(sync_dir, f"{sync_token}.txt")

    global_rank = _resolve_global_rank()
    if global_rank == 0:
        # Remove the sync file if it exists to avoid stale reads by other ranks
        if os.path.exists(sync_file):
            try:
                os.remove(sync_file)
            except OSError:
                pass

        run_time = datetime.now().strftime(timestamp_format)
        with open(sync_file, "w", encoding="utf-8") as f:
            f.write(run_time)
    else:
        # Non-zero ranks poll until rank 0 publishes a *fresh* timestamp.
        timeout = time.monotonic() + 1200.0
        while True:
            if os.path.exists(sync_file):
                try:
                    with open(sync_file, "r", encoding="utf-8") as f:
                        run_time = f.read().strip()
                    # Check if the timestamp is fresh (within 60 seconds)
                    # This prevents reading a stale timestamp from a previous run
                    dt = datetime.strptime(run_time, timestamp_format)
                    if abs((datetime.now() - dt).total_seconds()) < 60:
                        break
                except (ValueError, OSError):
                    # File might be empty or partially written, or format mismatch
                    pass

            if time.monotonic() > timeout:
                raise TimeoutError(
                    "Timed out waiting for rank 0 to write synchronized timestamp."
                )
            time.sleep(0.1)

    os.environ[env_key] = run_time
    return run_time
|
ldf_utils/math/__init__.py
ADDED
|
File without changes
|
ldf_utils/math/quaternion.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2018-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
_EPS4 = np.finfo(float).eps * 4.0
|
| 12 |
+
|
| 13 |
+
_FLOAT_EPS = np.finfo(np.float64).eps
|
| 14 |
+
|
| 15 |
+
# PyTorch-backed implementations
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def qinv(q):
    """Return the conjugate of unit quaternion(s) q, shape (*, 4).

    For unit quaternions the conjugate equals the inverse: the vector part
    (x, y, z) is negated while the scalar part w is kept.
    """
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    sign = torch.ones_like(q)
    sign[..., 1:] = -sign[..., 1:]
    return q * sign
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def qinv_np(q):
    """NumPy wrapper around :func:`qinv`; accepts and returns ndarrays of shape (*, 4)."""
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    tensor = torch.from_numpy(q).float()
    return qinv(tensor).numpy()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def qnormalize(q):
    """Normalize quaternion(s) q, shape (*, 4), to unit norm along the last axis."""
    assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
    norm = torch.norm(q, dim=-1, keepdim=True)
    return q / norm
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def qmul(q, r):
    """
    Multiply quaternion(s) q with quaternion(s) r (Hamilton product).

    Expects two equally-sized tensors of shape (*, 4), where * denotes any
    number of dimensions. Returns q*r as a tensor of shape (*, 4).
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4

    out_shape = q.shape

    # Outer product of the two quaternions: prod[b, i, j] = r[b, i] * q[b, j].
    prod = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))

    w = prod[:, 0, 0] - prod[:, 1, 1] - prod[:, 2, 2] - prod[:, 3, 3]
    x = prod[:, 0, 1] + prod[:, 1, 0] - prod[:, 2, 3] + prod[:, 3, 2]
    y = prod[:, 0, 2] + prod[:, 1, 3] + prod[:, 2, 0] - prod[:, 3, 1]
    z = prod[:, 0, 3] - prod[:, 1, 2] + prod[:, 2, 1] + prod[:, 3, 0]
    return torch.stack((w, x, y, z), dim=1).view(out_shape)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def qrot(q, v):
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.

    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    out_shape = list(v.shape)
    q_flat = q.contiguous().view(-1, 4)
    v_flat = v.contiguous().view(-1, 3)

    # Rodrigues-style quaternion rotation: v' = v + 2 * (w * (u x v) + u x (u x v))
    axis = q_flat[:, 1:]
    uv = torch.cross(axis, v_flat, dim=1)
    uuv = torch.cross(axis, uv, dim=1)
    rotated = v_flat + 2 * (q_flat[:, :1] * uv + uuv)
    return rotated.view(out_shape)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def qeuler(q, order, epsilon=0, deg=True):
    """
    Convert quaternion(s) q to Euler angles.

    Expects a tensor of shape (*, 4), where * denotes any number of dimensions.

    Args:
        q: quaternion tensor (w, x, y, z) of shape (*, 4).
        order: intrinsic rotation order, one of "xyz", "yzx", "zxy",
            "xzy", "yxz", "zyx".
        epsilon: clamp margin for the asin argument to avoid NaNs at
            gimbal-lock singularities.
        deg: if True (default), return degrees; otherwise radians.

    Returns:
        A tensor of shape (*, 3).

    Raises:
        ValueError: if `order` is not a supported rotation order.
    """
    assert q.shape[-1] == 4

    original_shape = list(q.shape)
    original_shape[-1] = 3
    q = q.view(-1, 4)

    q0 = q[:, 0]
    q1 = q[:, 1]
    q2 = q[:, 2]
    q3 = q[:, 3]

    if order == "xyz":
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    elif order == "yzx":
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
    elif order == "zxy":
        x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == "xzy":
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
    elif order == "yxz":
        x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == "zyx":
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    else:
        # Original code used a bare `raise` with no active exception, which
        # produces "RuntimeError: No active exception to re-raise".
        raise ValueError(f"Unsupported rotation order: {order!r}")

    if deg:
        return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
    else:
        return torch.stack((x, y, z), dim=1).view(original_shape)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Numpy-backed implementations
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def qmul_np(q, r):
    """NumPy wrapper around :func:`qmul`; multiplies ndarrays of shape (*, 4)."""
    q_t = torch.from_numpy(q).contiguous().float()
    r_t = torch.from_numpy(r).contiguous().float()
    return qmul(q_t, r_t).numpy()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def qrot_np(q, v):
    """NumPy wrapper around :func:`qrot`; rotates vectors v by quaternions q."""
    q_t = torch.from_numpy(q).contiguous().float()
    v_t = torch.from_numpy(v).contiguous().float()
    return qrot(q_t, v_t).numpy()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def qeuler_np(q, order, epsilon=0, use_gpu=False):
    """NumPy wrapper around :func:`qeuler`.

    Optionally runs on CUDA when use_gpu=True. Note: the `deg` flag of
    qeuler is left at its default, so output is in degrees.
    """
    if use_gpu:
        q_t = torch.from_numpy(q).cuda().float()
        return qeuler(q_t, order, epsilon).cpu().numpy()
    q_t = torch.from_numpy(q).contiguous().float()
    return qeuler(q_t, order, epsilon).numpy()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def qfix(q):
    """
    Enforce quaternion continuity across the time dimension by selecting
    the representation (q or -q) with minimal distance (equivalently,
    maximal dot product) between two consecutive frames.

    Expects an ndarray of shape (L, J, 4), where L is the sequence length
    and J is the number of joints. Returns an ndarray of the same shape.
    """
    assert len(q.shape) == 3
    assert q.shape[-1] == 4

    fixed = q.copy()
    # Dot product between consecutive frames, per joint: shape (L-1, J).
    dots = np.sum(q[1:] * q[:-1], axis=2)
    # A negative dot toggles the sign for all subsequent frames, so a running
    # parity (cumsum mod 2) tells which frames to negate.
    flip = (np.cumsum(dots < 0, axis=0) % 2).astype(bool)
    fixed[1:][flip] *= -1
    return fixed
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def euler2quat(e, order, deg=True):
    """
    Convert Euler angles to quaternions (torch version).

    Args:
        e: tensor of Euler angles, shape (*, 3).
        order: rotation order string composed of "x", "y", "z", e.g. "xyz".
        deg: if True (default), `e` is interpreted as degrees.

    Returns:
        Quaternion tensor (w, x, y, z) of shape (*, 4).

    Raises:
        ValueError: if `order` contains a character other than x/y/z.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.view(-1, 3)

    # Convert to radians if the angles were given in degrees.
    if deg:
        e = e * np.pi / 180.0

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    # Elementary rotations about each axis as quaternions.
    rx = torch.stack(
        (torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)),
        dim=1,
    )
    ry = torch.stack(
        (torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)),
        dim=1,
    )
    rz = torch.stack(
        (torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)),
        dim=1,
    )

    result = None
    for coord in order:
        if coord == "x":
            r = rx
        elif coord == "y":
            r = ry
        elif coord == "z":
            r = rz
        else:
            # Original code used a bare `raise` with no active exception,
            # which produces "RuntimeError: No active exception to re-raise".
            raise ValueError(f"Unsupported axis in rotation order: {coord!r}")
        if result is None:
            result = r
        else:
            result = qmul(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ["xyz", "yzx", "zxy"]:
        result *= -1

    return result.view(original_shape)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def expmap_to_quaternion(e):
    """
    Convert axis-angle rotations (aka exponential maps) to quaternions.

    Stable formula from "Practical Parameterization of Rotations Using the
    Exponential Map". Expects an ndarray of shape (*, 3), where * denotes
    any number of dimensions. Returns an ndarray of shape (*, 4).
    """
    assert e.shape[-1] == 3

    out_shape = list(e.shape)
    out_shape[-1] = 4
    flat = e.reshape(-1, 3)

    # Rotation angle is the vector norm; np.sinc gives sin(pi*x)/(pi*x),
    # hence the division by pi to obtain sin(theta/2)/theta stably near 0.
    angle = np.linalg.norm(flat, axis=1).reshape(-1, 1)
    w = np.cos(0.5 * angle).reshape(-1, 1)
    xyz = 0.5 * np.sinc(0.5 * angle / np.pi) * flat
    return np.concatenate((w, xyz), axis=1).reshape(out_shape)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def euler_to_quaternion(e, order):
    """
    Convert Euler angles to quaternions (NumPy version, radians).

    Args:
        e: ndarray of Euler angles in radians, shape (*, 3).
        order: rotation order string composed of "x", "y", "z", e.g. "xyz".

    Returns:
        Quaternion ndarray (w, x, y, z) of shape (*, 4).

    Raises:
        ValueError: if `order` contains a character other than x/y/z.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.reshape(-1, 3)

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    # Elementary rotations about each axis as quaternions.
    rx = np.stack(
        (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1
    )
    ry = np.stack(
        (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1
    )
    rz = np.stack(
        (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1
    )

    result = None
    for coord in order:
        if coord == "x":
            r = rx
        elif coord == "y":
            r = ry
        elif coord == "z":
            r = rz
        else:
            # Original code used a bare `raise` with no active exception,
            # which produces "RuntimeError: No active exception to re-raise".
            raise ValueError(f"Unsupported axis in rotation order: {coord!r}")
        if result is None:
            result = r
        else:
            result = qmul_np(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ["xyz", "yzx", "zxy"]:
        result *= -1

    return result.reshape(original_shape)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def quaternion_to_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # Normalization factor: 2 / |q|^2, so non-unit quaternions are handled.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    entries = (
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def quaternion_to_matrix_np(quaternions):
    """NumPy wrapper around :func:`quaternion_to_matrix`."""
    q_t = torch.from_numpy(quaternions).contiguous().float()
    return quaternion_to_matrix(q_t).numpy()
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def quaternion_to_cont6d_np(quaternions):
    """Convert quaternions to the continuous 6D rotation representation
    (first two columns of the rotation matrix), NumPy version."""
    rotmat = quaternion_to_matrix_np(quaternions)
    return np.concatenate([rotmat[..., 0], rotmat[..., 1]], axis=-1)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def quaternion_to_cont6d(quaternions):
    """Convert quaternions to the continuous 6D rotation representation
    (first two columns of the rotation matrix), torch version."""
    rotmat = quaternion_to_matrix(quaternions)
    return torch.cat([rotmat[..., 0], rotmat[..., 1]], dim=-1)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def cont6d_to_matrix(cont6d):
    """Convert a continuous 6D rotation representation to rotation matrices.

    The 6D vector holds two (approximate) column vectors; Gram-Schmidt
    orthonormalization rebuilds a proper orthonormal frame (x, y, z) which
    becomes the columns of the returned (..., 3, 3) matrix.
    """
    assert cont6d.shape[-1] == 6, "The last dimension must be 6"
    raw_x = cont6d[..., 0:3]
    raw_y = cont6d[..., 3:6]

    x_axis = raw_x / torch.norm(raw_x, dim=-1, keepdim=True)
    z_axis = torch.cross(x_axis, raw_y, dim=-1)
    z_axis = z_axis / torch.norm(z_axis, dim=-1, keepdim=True)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)

    # Stack the three axes as matrix columns.
    return torch.cat(
        [x_axis[..., None], y_axis[..., None], z_axis[..., None]], dim=-1
    )
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def cont6d_to_matrix_np(cont6d):
    """NumPy wrapper around :func:`cont6d_to_matrix`."""
    c_t = torch.from_numpy(cont6d).contiguous().float()
    return cont6d_to_matrix(c_t).numpy()
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def qpow(q0, t, dtype=torch.float):
    """Raise quaternion(s) q0 to power(s) t via the axis-angle form.

    Args:
        q0: tensor of quaternions, shape (*, 4); normalized internally.
        t: tensor of powers, or a plain number.
        dtype: output dtype.

    Returns:
        If t is a tensor: tensor of shape t.shape + q0.shape;
        otherwise a tensor of shape q0.shape.
    """
    q0 = qnormalize(q0)
    # Clamp into the valid acos domain to avoid NaNs from rounding error.
    theta0 = torch.acos(torch.clamp(q0[..., 0], -1.0, 1.0))

    # Where theta0 is (numerically) zero, substitute a tiny epsilon so the
    # sin(theta0) division below stays finite. NOTE: the original computed
    # `(1 - mask)` on a bool tensor, which raises in modern PyTorch
    # ("Subtraction with a bool tensor is not supported"); a float mask
    # preserves the intended arithmetic.
    mask = ((theta0 <= 10e-10) & (theta0 >= -10e-10)).to(theta0.dtype)
    theta0 = (1 - mask) * theta0 + mask * 10e-10
    v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)

    if isinstance(t, torch.Tensor):
        q = torch.zeros(t.shape + q0.shape)
        theta = t.view(-1, 1) * theta0.view(1, -1)
    else:  # if t is a number
        q = torch.zeros(q0.shape)
        theta = t * theta0

    q[..., 0] = torch.cos(theta)
    q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)

    return q.to(dtype)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def qslerp(q0, q1, t):
    """
    Spherical linear interpolation between two quaternions.

    Args:
        q0: starting quaternion.
        q1: ending quaternion.
        t: tensor of interpolation points along the way.

    Returns:
        Tensor of slerps with shape t.shape + q0.shape.
    """
    q0 = qnormalize(q0)
    q1 = qnormalize(q1)
    # (q1 * q0^-1)^t carries the fractional rotation from q0 toward q1.
    delta_pow = qpow(qmul(q1, qinv(q0)), t)

    # Broadcast q0 up to the shape of the interpolated deltas.
    q0_view = torch.Size([1] * len(t.shape)) + q0.shape
    q0_expanded = (
        q0.contiguous().view(q0_view).expand(t.shape + q0.shape).contiguous()
    )
    return qmul(delta_pow, q0_expanded)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def qbetween(v0, v1):
    """
    Find the quaternion that rotates v0 onto v1.

    Args:
        v0, v1: tensors of 3D vectors, shape (*, 3).

    Returns:
        Unit quaternion tensor of shape (*, 4).

    NOTE(review): antiparallel v0/v1 yield a zero quaternion before
    normalization (division by zero) — callers must avoid that case.
    """
    assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
    assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"

    # Pin the cross product to the last axis. The original relied on
    # torch.cross's legacy default, which picks the FIRST dimension of
    # size 3 and silently mis-selects a batch axis of size 3.
    v = torch.cross(v0, v1, dim=-1)
    w = torch.sqrt(
        (v0**2).sum(dim=-1, keepdim=True) * (v1**2).sum(dim=-1, keepdim=True)
    ) + (v0 * v1).sum(dim=-1, keepdim=True)
    return qnormalize(torch.cat([w, v], dim=-1))
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def qbetween_np(v0, v1):
    """
    Find the quaternion that rotates v0 onto v1 (NumPy wrapper).
    """
    assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
    assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"

    a = torch.from_numpy(v0).float()
    b = torch.from_numpy(v1).float()
    return qbetween(a, b).numpy()
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def lerp(p0, p1, t):
    """Linear interpolation between p0 and p1 at fraction(s) t.

    Returns a tensor of shape t.shape + p0.shape.
    """
    if not isinstance(t, torch.Tensor):
        t = torch.Tensor([t])

    out_shape = t.shape + p0.shape
    t_view = t.shape + torch.Size([1] * len(p0.shape))
    p_view = torch.Size([1] * len(t.shape)) + p0.shape

    start = p0.view(p_view).expand(out_shape)
    end = p1.view(p_view).expand(out_shape)
    frac = t.view(t_view).expand(out_shape)

    return start + frac * (end - start)
|
ldf_utils/motion_process.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from ldf_utils.math.quaternion import *
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Motion data structure:
|
| 9 |
+
(B: batch size)
|
| 10 |
+
root_rot_velocity (B, seq_len, 1)
|
| 11 |
+
root_linear_velocity (B, seq_len, 2)
|
| 12 |
+
root_y (B, seq_len, 1)
|
| 13 |
+
ric_data (B, seq_len, (joint_num - 1)*3)
|
| 14 |
+
rot_data (B, seq_len, (joint_num - 1)*6)
|
| 15 |
+
local_velocity (B, seq_len, joint_num*3)
|
| 16 |
+
foot contact (B, seq_len, 4)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def recover_root_rot_pos(data):
    """Recover the root joint's Y-axis rotation and world position from a
    motion feature tensor.

    Args:
        data: feature tensor of shape (..., seq_len, 263-ish) where channel 0
            is the root rotation velocity, channels 1:3 the root linear XZ
            velocity, and channel 3 the root height Y.
            (Layout per the module-level docstring; confirm against encoder.)

    Returns:
        r_rot_quat: root Y-rotation quaternions, shape (..., seq_len, 4).
        r_pos: root world positions, shape (..., seq_len, 3).
    """
    # recover root rotation and position
    rot_vel = data[..., 0]
    r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
    """Get Y-axis rotation from rotation velocity"""
    # Frame i's angle integrates velocities of frames 0..i-1 (one-frame delay).
    r_rot_ang[..., 1:] = rot_vel[..., :-1]
    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)

    # Quaternion for rotation about the Y axis: (cos a, 0, sin a, 0).
    r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
    r_rot_quat[..., 0] = torch.cos(r_rot_ang)
    r_rot_quat[..., 2] = torch.sin(r_rot_ang)

    # XZ velocities, delayed by one frame like the rotation above.
    r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
    r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
    """Add Y-axis rotation to root position"""
    # Rotate heading-local velocities into world space before integrating.
    r_pos = qrot(qinv(r_rot_quat), r_pos)

    r_pos = torch.cumsum(r_pos, dim=-2)

    # Root height is stored directly, not integrated.
    r_pos[..., 1] = data[..., 3]
    return r_rot_quat, r_pos
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def recover_joint_positions_263(data: np.ndarray, joints_num: int) -> np.ndarray:
    """
    Recovers 3D joint positions from the rotation-invariant local positions (ric_data).
    This is the most direct way to get the skeleton for animation.

    Args:
        data: feature ndarray of shape (seq_len, 263) in the HumanML3D-style
            layout described at module top.
        joints_num: number of skeleton joints (root included).

    Returns:
        ndarray of shape (seq_len, joints_num, 3) with world-space joints.
    """
    feature_vec = torch.from_numpy(data).unsqueeze(0).float()
    r_rot_quat, r_pos = recover_root_rot_pos(feature_vec)
    # Channels 4 .. 4+(J-1)*3 hold the non-root joints' local positions.
    positions = feature_vec[..., 4 : (joints_num - 1) * 3 + 4]
    positions = positions.view(positions.shape[:-1] + (-1, 3))
    """Add Y-axis rotation to local joints"""
    positions = qrot(
        qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions
    )
    """Add root XZ to joints"""
    positions[..., 0] += r_pos[..., 0:1]
    positions[..., 2] += r_pos[..., 2:3]
    """Concatenate root and joints"""
    positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
    joints_np = positions.squeeze(0).detach().cpu().numpy()
    return joints_np
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class StreamJointRecovery263:
    """
    Stream version of recover_joint_positions_263 that processes one frame at a time.
    Maintains cumulative state for rotation angles and positions.

    Key insight: The batch version uses PREVIOUS frame's velocity for the current frame,
    so we need to delay the velocity application by one frame.

    Args:
        joints_num: Number of joints in the skeleton
        smoothing_alpha: EMA smoothing factor (0.0 to 1.0)
            - 1.0 = no smoothing (default), output follows input exactly
            - 0.0 = infinite smoothing, output never changes
            - Recommended values: 0.3-0.7 for visible smoothing
            - Formula: smoothed = alpha * current + (1 - alpha) * previous
    """

    def __init__(self, joints_num: int, smoothing_alpha: float = 1.0):
        self.joints_num = joints_num
        # Clip to [0, 1] so an out-of-range alpha cannot invert the EMA.
        self.smoothing_alpha = np.clip(smoothing_alpha, 0.0, 1.0)
        self.reset()

    def reset(self):
        """Reset the accumulated state"""
        # Running integral of root rotation velocity (radians about Y).
        self.r_rot_ang_accum = 0.0
        # Running integral of root world position (X, Y, Z).
        self.r_pos_accum = np.array([0.0, 0.0, 0.0])
        # Store previous frame's velocities for delayed application
        self.prev_rot_vel = 0.0
        self.prev_linear_vel = np.array([0.0, 0.0])
        # Store previous smoothed joints for EMA
        self.prev_smoothed_joints = None

    def process_frame(self, frame_data: np.ndarray) -> np.ndarray:
        """
        Process a single frame and return joint positions for that frame.

        Args:
            frame_data: numpy array of shape (263,) for a single frame

        Returns:
            joints: numpy array of shape (joints_num, 3) representing joint positions
        """
        # Convert to torch tensor
        feature_vec = torch.from_numpy(frame_data).float()

        # Extract current frame's velocities (will be used in NEXT frame)
        curr_rot_vel = feature_vec[0].item()
        curr_linear_vel = feature_vec[1:3].numpy()

        # Update accumulated rotation angle with PREVIOUS frame's velocity FIRST
        # This matches the batch processing: r_rot_ang[i] uses rot_vel[i-1]
        self.r_rot_ang_accum += self.prev_rot_vel

        # Calculate current rotation quaternion using updated accumulated angle
        # Y-axis rotation quaternion: (cos a, 0, sin a, 0).
        r_rot_quat = torch.zeros(4)
        r_rot_quat[0] = np.cos(self.r_rot_ang_accum)
        r_rot_quat[2] = np.sin(self.r_rot_ang_accum)

        # Create velocity vector with Y=0 using PREVIOUS frame's velocity
        r_vel = np.array([self.prev_linear_vel[0], 0.0, self.prev_linear_vel[1]])

        # Apply inverse rotation to velocity using CURRENT rotation
        r_vel_torch = torch.from_numpy(r_vel).float()
        r_vel_rotated = qrot(qinv(r_rot_quat).unsqueeze(0), r_vel_torch.unsqueeze(0))
        r_vel_rotated = r_vel_rotated.squeeze(0).numpy()

        # Update accumulated position with rotated velocity
        self.r_pos_accum += r_vel_rotated

        # Get Y position from data
        r_pos = self.r_pos_accum.copy()
        r_pos[1] = feature_vec[3].item()

        # Extract local joint positions
        positions = feature_vec[4 : (self.joints_num - 1) * 3 + 4]
        positions = positions.view(-1, 3)

        # Apply inverse rotation to local joints
        r_rot_quat_expanded = (
            qinv(r_rot_quat).unsqueeze(0).expand(positions.shape[0], 4)
        )
        positions = qrot(r_rot_quat_expanded, positions)

        # Add root XZ to joints
        positions[:, 0] += r_pos[0]
        positions[:, 2] += r_pos[2]

        # Concatenate root and joints
        r_pos_torch = torch.from_numpy(r_pos).float()
        positions = torch.cat([r_pos_torch.unsqueeze(0), positions], dim=0)

        # Convert to numpy
        joints_np = positions.detach().cpu().numpy()

        # Apply EMA smoothing if enabled
        if self.smoothing_alpha < 1.0:
            if self.prev_smoothed_joints is None:
                # First frame, no smoothing possible
                self.prev_smoothed_joints = joints_np.copy()
            else:
                # EMA: smoothed = alpha * current + (1 - alpha) * previous
                joints_np = (
                    self.smoothing_alpha * joints_np
                    + (1.0 - self.smoothing_alpha) * self.prev_smoothed_joints
                )
                self.prev_smoothed_joints = joints_np.copy()

        # Store current velocities for next frame
        self.prev_rot_vel = curr_rot_vel
        self.prev_linear_vel = curr_linear_vel

        return joints_np
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def accumulate_rotations(relative_rotations):
    """Compose a sequence of frame-to-frame rotation matrices into absolute ones.

    Output i equals relative_rotations[i] @ ... @ relative_rotations[0].
    """
    composed = [relative_rotations[0]]
    for rel in relative_rotations[1:]:
        composed.append(np.matmul(rel, composed[-1]))
    return np.array(composed)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def recover_from_local_position(final_x, njoint):
    """Recover world-space joint positions from a heading-local feature layout.

    Expected per-frame layout of `final_x` (nfrm, 8 + 3*njoint [+ ...]):
        [0:2]  root XZ velocity without heading,
        [2:8]  per-frame heading rotation *difference* in 6D form,
        [8:8+3*njoint]  joint positions without heading.
    (Layout inferred from the slicing below — confirm against the encoder.)

    Returns:
        ndarray of shape (nfrm, njoint, 3) with heading and root translation
        re-applied.
    """
    nfrm, _ = final_x.shape
    positions_no_heading = final_x[:, 8 : 8 + 3 * njoint].reshape(
        nfrm, -1, 3
    )  # frames, njoints * 3
    velocities_root_xy_no_heading = final_x[:, :2]  # frames, 2
    global_heading_diff_rot = final_x[:, 2:8]  # frames, 6

    # recover global heading
    # Compose the per-frame heading differences into absolute heading matrices.
    global_heading_rot = accumulate_rotations(
        rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()
    )
    # Rotation matrices are orthogonal, so transpose == inverse.
    inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
    # add global heading to position
    positions_with_heading = np.matmul(
        np.repeat(inv_global_heading_rot[:, None, :, :], njoint, axis=1),
        positions_no_heading[..., None],
    ).squeeze(-1)

    # recover root translation
    # add heading to velocities_root_xy_no_heading

    velocities_root_xyz_no_heading = np.zeros(
        (
            velocities_root_xy_no_heading.shape[0],
            3,
        )
    )
    # Lift XZ-plane velocities into 3D (Y stays zero).
    velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
    velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
    # Frame i's velocity is rotated by frame i-1's heading (one-frame delay).
    velocities_root_xyz_no_heading[1:, :] = np.matmul(
        inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]
    ).squeeze(-1)

    # Integrate velocities into a running root translation.
    root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)

    # add root translation
    positions_with_heading[:, :, 0] += root_translation[:, 0:1]
    positions_with_heading[:, :, 2] += root_translation[:, 2:]

    return positions_with_heading
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert a 6D rotation representation (..., 6) to rotation matrices (..., 3, 3)
    via Gram-Schmidt orthonormalization of the two embedded vectors."""
    first, second = d6[..., :3], d6[..., 3:]
    row1 = F.normalize(first, dim=-1)
    # Remove row1's component from the second vector, then normalize.
    projected = second - (row1 * second).sum(-1, keepdim=True) * row1
    row2 = F.normalize(projected, dim=-1)
    row3 = torch.cross(row1, row2, dim=-1)
    return torch.stack((row1, row2, row3), dim=-2)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _copysign(a, b):
|
| 240 |
+
signs_differ = (a < 0) != (b < 0)
|
| 241 |
+
return torch.where(signs_differ, -a, a)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _sqrt_positive_part(x):
|
| 245 |
+
ret = torch.zeros_like(x)
|
| 246 |
+
positive_mask = x > 0
|
| 247 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 248 |
+
return ret
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def matrix_to_quaternion(matrix):
    """Convert rotation matrices to real-first quaternions (w, x, y, z).

    Args:
        matrix: Tensor of shape (..., 3, 3) containing rotation matrices.

    Returns:
        Tensor of shape (..., 4) holding unit quaternions.

    Raises:
        ValueError: If the trailing two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        # Fixed: the message previously rendered a stray literal "f" before
        # the shape ("shape f{...}" was written inside the f-string).
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    m00 = matrix[..., 0, 0]
    m11 = matrix[..., 1, 1]
    m22 = matrix[..., 2, 2]
    # Component magnitudes from the diagonal (trace-based formulas);
    # _sqrt_positive_part guards against tiny negative values from
    # floating-point error.
    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
    # Signs of the vector part are recovered from the skew-symmetric
    # off-diagonal differences.
    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack((o0, o1, o2, o3), -1)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def quaternion_to_axis_angle(quaternions):
    """Convert real-first quaternions (..., 4) to axis-angle vectors (..., 3).

    The returned vector points along the rotation axis and its norm is the
    rotation angle in radians.
    """
    vec_norm = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(vec_norm, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    tiny = angles.abs() < eps
    # Divisor sin(theta/2) / theta: use the exact ratio where the angle is
    # large enough, and the second-order Taylor expansion
    # sin(x/2)/x ~ 1/2 - x^2/48 near zero to avoid dividing by ~0.
    safe_angles = torch.where(tiny, torch.ones_like(angles), angles)
    ratio = torch.where(
        tiny,
        0.5 - (angles * angles) / 48,
        torch.sin(half_angles) / safe_angles,
    )
    return quaternions[..., 1:] / ratio
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def matrix_to_axis_angle(matrix):
    """Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3)."""
    quaternions = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(quaternions)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def rotations_matrix_to_smpl85(rotations_matrix, translation):
    """Pack per-joint rotations and root translation into an 85-dim layout.

    Per frame the output concatenates: flattened axis-angle joint rotations,
    6 zeros, the 3D root translation, then 10 zeros.

    Args:
        rotations_matrix: np.ndarray of shape (nfrm, njoint, 3, 3).
        translation: np.ndarray of shape (nfrm, 3).

    Returns:
        np.ndarray of shape (nfrm, 85).
    """
    nfrm = rotations_matrix.shape[0]
    aa = matrix_to_axis_angle(torch.from_numpy(rotations_matrix)).numpy()
    aa = aa.reshape(nfrm, -1)
    pad6 = np.zeros((nfrm, 6))
    pad10 = np.zeros((nfrm, 10))
    return np.concatenate([aa, pad6, translation, pad10], axis=-1)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def recover_from_local_rotation(final_x, njoint):
    """Recover an SMPL-85 motion array from a local-rotation feature vector.

    Per-frame layout of ``final_x`` as indexed below:
      [0:2]                      root XY velocity in the heading-free frame,
      [2:8]                      per-frame heading *delta* as a 6D rotation,
      [8 : 8+3*njoint]           heading-free joint positions,
      [8+6*njoint : 8+12*njoint] per-joint rotations as 6D rotations.
    (The slice [8+3*njoint : 8+6*njoint] is unused here — presumably joint
    velocities; confirm against the encoder.)

    Args:
        final_x: np.ndarray of shape (nfrm, feature_dim).
        njoint: number of joints encoded in the feature vector.

    Returns:
        np.ndarray of shape (nfrm, 85) produced by rotations_matrix_to_smpl85.
    """
    nfrm, _ = final_x.shape
    # Decode per-joint 6D rotations into (nfrm, njoint, 3, 3) matrices.
    rotations_matrix = rotation_6d_to_matrix(
        torch.from_numpy(final_x[:, 8 + 6 * njoint : 8 + 12 * njoint]).reshape(
            nfrm, -1, 6
        )
    ).numpy()
    global_heading_diff_rot = final_x[:, 2:8]
    velocities_root_xy_no_heading = final_x[:, :2]
    positions_no_heading = final_x[:, 8 : 8 + 3 * njoint].reshape(nfrm, -1, 3)
    # Root height (Y of joint 0) is taken directly from the positions.
    height = positions_no_heading[:, 0, 1]

    # Integrate the per-frame heading deltas into absolute heading rotations.
    global_heading_rot = accumulate_rotations(
        rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()
    )
    # Rotation matrices are orthogonal: transpose == inverse.
    inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
    # recover root rotation: re-apply the accumulated heading to joint 0.
    rotations_matrix[:, 0, ...] = np.matmul(
        inv_global_heading_rot, rotations_matrix[:, 0, ...]
    )
    # Lift the XY root velocities into 3D (Y left at zero; height is set
    # separately after integration).
    velocities_root_xyz_no_heading = np.zeros(
        (
            velocities_root_xy_no_heading.shape[0],
            3,
        )
    )
    velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
    velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
    # Rotate each frame's velocity by the previous frame's heading so the
    # per-frame deltas are expressed in a common frame before integration.
    velocities_root_xyz_no_heading[1:, :] = np.matmul(
        inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]
    ).squeeze(-1)
    # Integrate velocities into the root trajectory, then restore height.
    root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)
    root_translation[:, 1] = height
    smpl_85 = rotations_matrix_to_smpl85(rotations_matrix, root_translation)
    return smpl_85
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def recover_joint_positions_272(data: np.ndarray, joints_num) -> np.ndarray:
    """Recover (K, joints_num, 3) joint positions from 272-dim features."""
    recovered = recover_from_local_position(data, joints_num)
    return recovered
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def convert_motion_to_joints(
    motion_data: np.ndarray,
    dim: int,
    mean: np.ndarray = None,
    std: np.ndarray = None,
    joints_num=22,
):
    """
    Convert 263-dim or 272-dim motion features to joint positions.

    Args:
        motion_data: numpy array of shape (K, 263) or (K, 272), K = frames.
        dim: feature dimension of the representation (263 or 272).
        mean: optional de-normalization mean; applied only if std is also given.
        std: optional de-normalization std; applied only if mean is also given.
        joints_num: number of joints to recover (default 22).

    Returns:
        numpy array of shape (K, joints_num, 3) of joint positions.

    Raises:
        ValueError: if ``dim`` is neither 263 nor 272.
    """
    # De-normalize first so both decoders receive raw-scale features.
    if mean is not None and std is not None:
        motion_data = motion_data * std + mean
    # Dispatch on the representation size with early returns.
    if dim == 263:
        return recover_joint_positions_263(motion_data, joints_num)
    if dim == 272:
        return recover_joint_positions_272(motion_data, joints_num)
    raise ValueError(f"Unsupported motion data dimension: {dim}")
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3528a345e2795f0b28343896515adc2c14746567896c66620852678ff8d43a79
|
| 3 |
+
size 36753080
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
transformers>=4.30.0
|
| 4 |
+
huggingface_hub>=0.16.0
|
| 5 |
+
safetensors>=0.3.0
|
| 6 |
+
diffusers>=0.20.0
|
| 7 |
+
|
| 8 |
+
# Inference
|
| 9 |
+
lightning>=2.0.0
|
| 10 |
+
ftfy
|
| 11 |
+
|
| 12 |
+
# Configuration
|
| 13 |
+
omegaconf
|
| 14 |
+
|
| 15 |
+
# Utilities
|
| 16 |
+
numpy
|
| 17 |
+
|
| 18 |
+
# Note: flash-attn is required but needs special installation
|
| 19 |
+
# See README.md for installation instructions
|
vae.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5a40164154c476309ff952a4b7563750b7e76fbdd8d263ec261ad877cf452e7b
|
| 3 |
+
size 70027220
|