herrscher0 committed e86746e (0 parents)

Initial commit: FloodDiffusionTiny - Tiny text-to-motion model with UMT5-Base
.gitattributes ADDED
@@ -0,0 +1,36 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ldf_deps/t5_umt5-xxl-enc-bf16/google/umt5-xxl/tokenizer.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,54 @@
1
+ # Python cache
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+
8
+ # Virtual environments
9
+ venv/
10
+ env/
11
+ ENV/
12
+ .venv
13
+
14
+ # PyTorch/Model cache
15
+ *.pth~
16
+ *.safetensors~
17
+ checkpoint/
18
+ checkpoints/
19
+
20
+ # Hugging Face cache
21
+ .cache/
22
+ huggingface_cache/
23
+
24
+ # Generated outputs
25
+ outputs/
26
+ generated_motions/
27
+ *.npy
28
+ *.pkl
29
+
30
+ # IDE
31
+ .vscode/
32
+ .idea/
33
+ *.swp
34
+ *.swo
35
+ *~
36
+
37
+ # OS
38
+ .DS_Store
39
+ Thumbs.db
40
+
41
+ # Jupyter
42
+ .ipynb_checkpoints/
43
+ *.ipynb
44
+
45
+ # Logs
46
+ *.log
47
+ logs/
48
+ wandb/
49
+
50
+ # Test outputs
51
+ test_output/
52
+ test_results/
53
+ tmp/
54
+
README.md ADDED
@@ -0,0 +1,177 @@
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - text-to-motion
5
+ - motion-generation
6
+ - diffusion-forcing
7
+ - humanml3d
8
+ - computer-animation
9
+ library_name: transformers
10
+ pipeline_tag: other
11
+ ---
12
+
13
+ # FloodDiffusion: Tailored Diffusion Forcing for Streaming Motion Generation
14
+
15
+ <div align="center">
16
+
17
+ **A TINY version of the original FloodDiffusion**
18
+
19
+ [Paper](https://arxiv.org/abs/2512.03520) | [Github](https://github.com/ShandaAI/FloodDiffusion) | [Project Page](https://shandaai.github.io/FloodDiffusion/)
20
+
21
+ </div>
22
+
23
+ ## Installation
24
+
25
+ ### Prerequisites
26
+
27
+ - Python 3.8+
28
+ - CUDA-capable GPU with 16GB+ VRAM (recommended)
29
+ - 16GB+ system RAM
30
+
31
+ ### Dependencies
32
+
33
+ **Step 1: Install basic dependencies**
34
+
35
+ ```bash
36
+ pip install torch transformers huggingface_hub
37
+ pip install lightning diffusers omegaconf ftfy numpy
38
+ ```
39
+
40
+ **Step 2: Install Flash Attention (Required)**
41
+
42
+ Flash Attention requires CUDA and may need compilation. Install it with:
43
+
44
+ ```bash
45
+ pip install flash-attn --no-build-isolation
46
+ ```
47
+
48
+ **Note:** Flash attention is **required** for this model. If installation fails, please refer to the [official flash-attention installation guide](https://github.com/Dao-AILab/flash-attention#installation-and-features).
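+
+ A quick way to confirm the install succeeded before loading the model (a minimal check, nothing model-specific):
+
+ ```python
+ # If this import fails, the model cannot run; reinstall flash-attn per the guide above.
+ import flash_attn
+ print(flash_attn.__version__)
+ ```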
49
+
50
+ ## Quick Start
51
+
52
+ ### Basic Usage
53
+
54
+ ```python
55
+ from transformers import AutoModel
56
+
57
+ # Load model
58
+ model = AutoModel.from_pretrained(
59
+ "ShandaAI/FloodDiffusionTiny",
60
+ trust_remote_code=True
61
+ )
62
+
63
+ # Generate motion from text (263-dim HumanML3D features)
64
+ motion = model("a person walking forward", length=60)
65
+ print(f"Generated motion: {motion.shape}") # (~240, 263)
66
+
67
+ # Generate motion as joint coordinates (22 joints × 3 coords) with EMA smoothing (alpha: 0.0-1.0)
68
+ motion_joints = model("a person walking forward", length=60, output_joints=True, smoothing_alpha=0.5)
69
+ print(f"Generated joints: {motion_joints.shape}") # (~240, 22, 3)
70
+ ```
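+
+ The outputs are plain NumPy arrays, so they can be saved or post-processed directly (illustrative file names):
+
+ ```python
+ import numpy as np
+
+ np.save("walk_features.npy", motion)        # (frames, 263) HumanML3D features
+ np.save("walk_joints.npy", motion_joints)   # (frames, 22, 3) joint coordinates
+ ```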
71
+
72
+ ### Batch Generation
73
+
74
+ ```python
75
+ # Generate multiple motions efficiently
76
+ texts = [
77
+ "a person walking forward",
78
+ "a person running quickly",
79
+ "a person jumping up and down"
80
+ ]
81
+ lengths = [60, 50, 40] # Different lengths for each motion
82
+
83
+ motions = model(texts, length=lengths)
84
+
85
+ for i, motion in enumerate(motions):
86
+ print(f"Motion {i}: {motion.shape}")
87
+ ```
88
+
89
+ ### Multi-Text Motion Transitions
90
+
91
+ ```python
92
+ # Generate a motion sequence with smooth transitions between actions
93
+ motion = model(
94
+ text=[["walk forward", "turn around", "run back"]],
95
+ length=[120],
96
+ text_end=[[40, 80, 120]] # Transition points in latent tokens
97
+ )
98
+
99
+ # Output: ~480 frames showing all three actions smoothly connected
100
+ print(f"Transition motion: {motion[0].shape}")
101
+ ```
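+
+ Because each latent token decodes to roughly 4 frames, the `text_end` values can be mapped to approximate frame indices in the output (a rough sketch using the 4x factor described in the API reference below):
+
+ ```python
+ text_end = [40, 80, 120]
+ frame_boundaries = [t * 4 for t in text_end]  # ~[160, 320, 480]
+ ```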
102
+
103
+ ## API Reference
104
+
105
+ ### `model(text, length=60, text_end=None, num_denoise_steps=None, output_joints=False, smoothing_alpha=1.0)`
106
+
107
+ Generate motion sequences from text descriptions.
108
+
109
+ **Parameters:**
110
+
111
+ - **text** (`str`, `List[str]`, or `List[List[str]]`): Text description(s)
112
+ - Single string: Generate one motion
113
+ - List of strings: Batch generation
114
+ - Nested list: Multiple text prompts per motion (for transitions)
115
+
116
+ - **length** (`int` or `List[int]`, default=60): Number of latent tokens to generate
117
+ - Output frames ≈ `length × 4` (due to VAE upsampling)
118
+ - Example: `length=60` → ~240 frames (~12 seconds at 20 FPS)
119
+
120
+ - **text_end** (`List[int]` or `List[List[int]]`, optional): Latent token positions for text transitions
121
+ - Only used when `text` is a nested list
122
+ - Specifies when to switch between different text descriptions
123
+ - **IMPORTANT**: Must have the same length as the corresponding text list
124
+ - Example: `text=[["walk", "turn", "sit"]]` requires `text_end=[[20, 40, 60]]` (3 endpoints for 3 texts)
125
+ - Must be in ascending order
126
+
127
+ - **num_denoise_steps** (`int`, optional): Number of denoising iterations
128
+ - Higher values produce better quality but slower generation
129
+ - Recommended range: 10-50
+ - Must be a multiple of the model's chunk size (5 in this configuration; the model asserts this at generation time)
130
+
131
+ - **output_joints** (`bool`, default=False): Output format selector
132
+ - `False`: Returns 263-dimensional HumanML3D features
133
+ - `True`: Returns 22×3 joint coordinates for direct visualization
134
+
135
+ - **smoothing_alpha** (`float`, default=1.0): EMA smoothing factor for joint positions (only used when `output_joints=True`)
136
+ - `1.0`: No smoothing (default)
137
+ - `0.5`: Medium smoothing (recommended for smoother animations)
138
+ - `0.0`: Maximum smoothing
139
+ - Range: 0.0 to 1.0
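+
+ Conceptually, the smoothing applied under `output_joints=True` is a per-joint exponential moving average; a minimal sketch of that behaviour (the actual implementation lives in `ldf_utils/motion_process.py` via `StreamJointRecovery263`):
+
+ ```python
+ import numpy as np
+
+ def ema_smooth(joints, alpha=0.5):
+     """joints: (frames, 22, 3). alpha=1.0 keeps the raw output; smaller alpha smooths more."""
+     out = np.empty_like(joints)
+     out[0] = joints[0]
+     for t in range(1, len(joints)):
+         out[t] = alpha * joints[t] + (1.0 - alpha) * out[t - 1]
+     return out
+ ```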
140
+
141
+ **Returns:**
142
+ - Single motion:
143
+ - `output_joints=False`: `numpy.ndarray` of shape `(frames, 263)`
144
+ - `output_joints=True`: `numpy.ndarray` of shape `(frames, 22, 3)`
145
+ - Batch: `List[numpy.ndarray]` with shapes as above
146
+
147
+ **Example:**
148
+ ```python
149
+ # Single generation (263-dim features)
150
+ motion = model("walk forward", length=60) # Returns (240, 263)
151
+
152
+ # Single generation (joint coordinates)
153
+ joints = model("walk forward", length=60, output_joints=True) # Returns (240, 22, 3)
154
+
155
+ # Batch generation
156
+ motions = model(["walk", "run"], length=[60, 50]) # Returns list of 2 arrays
157
+
158
+ # Multi-text transitions
159
+ motion = model(
160
+ [["walk", "turn"]],
161
+ length=[60],
162
+ text_end=[[30, 60]]
163
+ ) # Returns list with 1 array of shape (240, 263)
164
+ ```
165
+
166
+ ## Citation
167
+
168
+ If you use this model in your research, please cite:
169
+
170
+ ```bibtex
171
+ @article{cai2025flooddiffusion,
172
+ title={FloodDiffusion: Tailored Diffusion Forcing for Streaming Motion Generation},
173
+ author={Yiyi Cai and Yuhan Wu and Kunhang Li and You Zhou and Bo Zheng and Haiyang Liu},
174
+ journal={arXiv preprint arXiv:2512.03520},
175
+ year={2025}
176
+ }
177
+ ```
__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ FloodDiffusion - Text-to-Motion Generation
3
+
4
+ Usage:
5
+ from transformers import AutoModel
6
+
7
+ model = AutoModel.from_pretrained("ShandaAI/FloodDiffusionTiny", trust_remote_code=True)
8
+ motion = model("a person walking forward", length=60)
9
+ """
10
+
11
+ __version__ = "1.0.0"
config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "architectures": ["LDFModel"],
3
+ "model_type": "ldf_motion",
4
+ "auto_map": {
5
+ "AutoModel": "hf_pipeline.LDFModel",
6
+ "AutoConfig": "hf_pipeline.LDFConfig"
7
+ },
8
+ "torch_dtype": "float32",
9
+ "transformers_version": "4.30.0",
10
+ "license": "mit"
11
+ }
generate_ldf.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ from lightning import seed_everything
6
+ from safetensors.torch import load_file as load_safetensors
7
+
8
+ from ldf_utils.initialize import compare_statedict_and_parameters, instantiate, load_config
9
+
10
+ # Set tokenizers parallelism to false to avoid warnings in multiprocessing
11
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
12
+
13
+
14
+ def load_model_from_config():
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ torch.set_float32_matmul_precision("high")
17
+ cfg = load_config()
18
+ seed_everything(cfg.seed)
19
+
20
+ # Get the directory containing the config file
21
+ # Try to find config directory from sys.argv or use current directory
22
+ if '--config' in sys.argv:
23
+ config_idx = sys.argv.index('--config') + 1
24
+ config_dir = os.path.dirname(os.path.abspath(sys.argv[config_idx]))
25
+ else:
26
+ config_dir = os.getcwd()
27
+
28
+ vae = instantiate(
29
+ target=cfg.test_vae.target,
30
+ cfg=None,
31
+ hfstyle=False,
32
+ **cfg.test_vae.params,
33
+ )
34
+
35
+ # Handle relative paths
36
+ vae_path = cfg.test_vae_ckpt
37
+ if not os.path.isabs(vae_path):
38
+ vae_path = os.path.join(config_dir, vae_path)
39
+
40
+ # Load from safetensors (already contains EMA weights)
41
+ vae_state_dict = load_safetensors(vae_path)
42
+ vae.load_state_dict(vae_state_dict, strict=True)
43
+ print(f"Loaded VAE model from {vae_path}")
44
+
45
+ compare_statedict_and_parameters(
46
+ state_dict=vae.state_dict(),
47
+ named_parameters=vae.named_parameters(),
48
+ named_buffers=vae.named_buffers(),
49
+ )
50
+ vae.to(device)
51
+ vae.eval()
52
+
53
+ # Model - fix relative paths in model params
54
+ model_params = dict(cfg.model.params)
55
+ # Convert relative paths to absolute paths
56
+ if 'checkpoint_path' in model_params and model_params['checkpoint_path']:
57
+ if not os.path.isabs(model_params['checkpoint_path']):
58
+ model_params['checkpoint_path'] = os.path.join(config_dir, model_params['checkpoint_path'])
59
+ if 'tokenizer_path' in model_params and model_params['tokenizer_path']:
60
+ if not os.path.isabs(model_params['tokenizer_path']):
61
+ model_params['tokenizer_path'] = os.path.join(config_dir, model_params['tokenizer_path'])
62
+
63
+ model = instantiate(
64
+ target=cfg.model.target, cfg=None, hfstyle=False, **model_params
65
+ )
66
+
67
+ # Handle relative paths
68
+ model_path = cfg.test_ckpt
69
+ if not os.path.isabs(model_path):
70
+ model_path = os.path.join(config_dir, model_path)
71
+
72
+ # Load from safetensors (already contains EMA weights)
73
+ model_state_dict = load_safetensors(model_path)
74
+ model.load_state_dict(model_state_dict, strict=True)
75
+ print(f"Loaded model from {model_path}")
76
+
77
+ compare_statedict_and_parameters(
78
+ state_dict=model.state_dict(),
79
+ named_parameters=model.named_parameters(),
80
+ named_buffers=model.named_buffers(),
81
+ )
82
+ model.to(device)
83
+ model.eval()
84
+
85
+ return vae, model
86
+
87
+
88
+ @torch.inference_mode()
89
+ def generate_feature_stream(
90
+ model, feature_length, text, feature_text_end=None, num_denoise_steps=None
91
+ ):
92
+ """
93
+ Streaming interface for feature generation
94
+ Args:
95
+ model: Loaded model
96
+ feature_length: List[int], generation length for each sample
97
+ text: List[str] or List[List[str]], text prompts
98
+ feature_text_end: List[List[int]], time points where text ends (if text is list of list)
99
+ num_denoise_steps: Number of denoising steps
100
+ Yields:
101
+ dict: Contains "generated" (current generated feature segment)
102
+ """
103
+
104
+ # Construct input dict x
105
+ # stream_generate needs x to contain "feature_length", "text", "feature_text_end" (if text is list of list)
106
+ x = {"feature_length": torch.tensor(feature_length), "text": text}
107
+
108
+ if feature_text_end is not None:
109
+ x["feature_text_end"] = feature_text_end
110
+
111
+ # Call model's stream_generate
112
+ # Note: stream_generate is a generator
113
+ generator = model.stream_generate(x, num_denoise_steps=num_denoise_steps)
114
+
115
+ for step_output in generator:
116
+ # step_output is already a dict with "generated" key
117
+ yield step_output
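+
+ # Usage sketch (illustrative, not part of the original script): consume the generator
+ # and concatenate the committed feature segments as they arrive.
+ #
+ #   segments = []
+ #   for step in generate_feature_stream(model, [60], ["a person walks forward"]):
+ #       chunk = step["generated"][0]       # latent features for sample 0, or None once finished
+ #       if chunk is not None:
+ #           segments.append(chunk)
+ #   features = torch.cat(segments, dim=0)  # (tokens, 4); decode with vae.decode for 263-dim frames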
118
+
119
+
120
+ if __name__ == "__main__":
121
+ import argparse
122
+
123
+ parser = argparse.ArgumentParser()
124
+ parser.add_argument("--config", type=str, required=True, help="Path to config")
125
+ parser.add_argument(
126
+ "--text", type=str, default="a person walks forward", help="Text prompt"
127
+ )
128
+ parser.add_argument("--length", type=int, default=120, help="Motion length")
129
+ parser.add_argument(
130
+ "--output", type=str, default="output.mp4", help="Output video path"
131
+ )
132
+ parser.add_argument(
133
+ "--num_denoise_steps", type=int, default=None, help="Number of denoising steps"
134
+ )
135
+ args = parser.parse_args()
136
+
137
+ print("Loading model...")
138
+ vae, model = load_model_from_config()
139
+
hf_pipeline.py ADDED
@@ -0,0 +1,282 @@
1
+ """
2
+ LDF Model for Hugging Face Hub
3
+
4
+ Usage:
5
+ from transformers import AutoModel
6
+
7
+ model = AutoModel.from_pretrained("ShandaAI/FloodDiffusion", trust_remote_code=True)
8
+ motion = model("a person walking forward", length=60)
9
+ """
10
+
11
+ import torch
12
+ from transformers import PretrainedConfig, PreTrainedModel
13
+ from typing import Union, List, Optional
14
+ import os
15
+ import sys
16
+
17
+
18
+ class LDFConfig(PretrainedConfig):
19
+ """Configuration for LDF Motion Generation Model"""
20
+ model_type = "ldf_motion"
21
+
22
+ def __init__(
23
+ self,
24
+ input_dim=4,
25
+ output_dim=263,
26
+ **kwargs
27
+ ):
28
+ super().__init__(**kwargs)
29
+ self.input_dim = input_dim
30
+ self.output_dim = output_dim
31
+
32
+
33
+ class LDFModel(PreTrainedModel):
34
+ """
35
+ LDF Motion Generation Model
36
+
37
+ This model generates motion sequences from text descriptions using Latent Diffusion Forcing.
38
+
39
+ Example:
40
+ >>> from transformers import AutoModel
41
+ >>> model = AutoModel.from_pretrained("ShandaAI/FloodDiffusion", trust_remote_code=True)
42
+ >>> motion = model("a person walking forward", length=60)
43
+ >>> print(motion.shape) # (~240, 263)
44
+ """
45
+
46
+ config_class = LDFConfig
47
+
48
+ def __init__(self, config):
49
+ super().__init__(config)
50
+ self.config = config
51
+
52
+ # Will be loaded in from_pretrained
53
+ self.ldf_model = None
54
+ self.vae = None
55
+ self.model_dir = None # Store model directory for later use
56
+
57
+ def _load_models(self):
58
+ """Load the actual LDF and VAE models"""
59
+ if self.ldf_model is not None:
60
+ return # Already loaded
61
+
62
+ # Get the model directory - should be set by from_pretrained
63
+ if hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path):
64
+ model_dir = self.name_or_path
65
+ else:
66
+ raise RuntimeError(
67
+ "Model directory not found. Please use from_pretrained() to load the model."
68
+ )
69
+
70
+ # Save model_dir for later use (e.g., in output_joints conversion)
71
+ self.model_dir = model_dir
72
+
73
+ # Add model_dir to sys.path for imports
74
+ if model_dir not in sys.path:
75
+ sys.path.insert(0, model_dir)
76
+
77
+ # Use dynamic import to avoid HF's static import checker
78
+ import importlib
79
+ generate_ldf = importlib.import_module('generate_ldf')
80
+ load_model_from_config = generate_ldf.load_model_from_config
81
+
82
+ config_path = os.path.join(model_dir, "ldf.yaml")
83
+ old_argv = sys.argv
84
+ sys.argv = ['model', '--config', config_path]
85
+
86
+ try:
87
+ self.vae, self.ldf_model = load_model_from_config()
88
+
89
+ # Move to correct device
90
+ device = next(self.parameters()).device if list(self.parameters()) else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91
+ self.ldf_model = self.ldf_model.to(device)
92
+ self.vae = self.vae.to(device)
93
+ finally:
94
+ sys.argv = old_argv
95
+
96
+ @classmethod
97
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
98
+ """
99
+ Load pretrained model
100
+
101
+ Args:
102
+ pretrained_model_name_or_path: Model name or path
103
+ trust_remote_code: Must be True to load this custom model
104
+ **kwargs: Additional arguments
105
+
106
+ Returns:
107
+ LDFModel instance
108
+ """
109
+ # Check trust_remote_code
110
+ if not kwargs.get('trust_remote_code', False):
111
+ raise ValueError(
112
+ "Loading this model requires trust_remote_code=True. "
113
+ "Usage: AutoModel.from_pretrained(..., trust_remote_code=True)"
114
+ )
115
+
116
+ # Download if needed
117
+ if not os.path.exists(pretrained_model_name_or_path):
118
+ from huggingface_hub import snapshot_download
119
+ model_path = snapshot_download(repo_id=pretrained_model_name_or_path)
120
+ else:
121
+ model_path = pretrained_model_name_or_path
122
+
123
+ # Load config
124
+ config = LDFConfig.from_pretrained(model_path)
125
+
126
+ # Create model
127
+ model = cls(config)
128
+ model.name_or_path = model_path
129
+
130
+ # Load the actual models
131
+ model._load_models()
132
+
133
+ return model
134
+
135
+ def forward(
136
+ self,
137
+ text: Union[str, List[str], List[List[str]]],
138
+ length: Union[int, List[int]] = 60,
139
+ text_end: Optional[Union[List[int], List[List[int]]]] = None,
140
+ num_denoise_steps: Optional[int] = None,
141
+ **kwargs
142
+ ):
143
+ """
144
+ Generate motion from text
145
+
146
+ Args:
147
+ text: Text description(s)
148
+ length: Number of latent tokens (output frames ≈ length × 4)
149
+ text_end: Transition points for multi-text
150
+ num_denoise_steps: Number of denoising steps
151
+
152
+ Returns:
153
+ Generated motion sequence(s)
154
+ """
155
+ return self.__call__(text, length, text_end, num_denoise_steps, **kwargs)
156
+
157
+ @torch.no_grad()
158
+ def __call__(
159
+ self,
160
+ text: Union[str, List[str], List[List[str]]],
161
+ length: Union[int, List[int]] = 60,
162
+ text_end: Optional[Union[List[int], List[List[int]]]] = None,
163
+ num_denoise_steps: Optional[int] = None,
164
+ output_joints: bool = False,
165
+ smoothing_alpha: float = 1.0
166
+ ):
167
+ """
168
+ Generate motion sequences
169
+
170
+ Args:
171
+ text: Text description
172
+ - Single string: "walk" -> single sample
173
+ - String list: ["walk", "run"] -> batch
174
+ - Nested list: [["walk", "turn"], ["run", "jump"]] -> multi-text per sample
175
+ length: Number of latent tokens (frames ≈ length × 4)
176
+ text_end: Token positions for text switching
177
+ num_denoise_steps: Number of denoising steps
178
+ output_joints: If True, output 22×3 joint coordinates; if False (default), output 263-dim HumanML3D features
179
+ smoothing_alpha: EMA smoothing factor for joint positions (0.0-1.0, default=1.0 no smoothing)
180
+ - Only used when output_joints=True
181
+ - Recommended: 0.5 for smoother animations
182
+
183
+ Returns:
184
+ numpy.ndarray or list of arrays
185
+ - If output_joints=False: shape (frames, 263)
186
+ - If output_joints=True: shape (frames, 22, 3)
187
+ """
188
+ # Ensure models are loaded
189
+ self._load_models()
190
+
191
+ # Normalize inputs
192
+ is_single = not isinstance(length, list)
193
+ if is_single:
194
+ text_batch = [text]
195
+ length_batch = [length]
196
+ text_end_batch = [text_end] if text_end is not None else None
197
+ else:
198
+ text_batch = text
199
+ length_batch = length
200
+ text_end_batch = text_end
201
+
202
+ # Validate text_end alignment with text
203
+ if text_end_batch is not None:
204
+ for i, (txt, te) in enumerate(zip(text_batch, text_end_batch)):
205
+ if isinstance(txt, list) and te is not None:
206
+ if len(txt) != len(te):
207
+ raise ValueError(
208
+ f"Batch {i}: text has {len(txt)} segments but text_end has {len(te)} endpoints. "
209
+ f"They must match! text={txt}, text_end={te}"
210
+ )
211
+
212
+ batch_size = len(text_batch)
213
+
214
+ # Construct input dict for model
215
+ x = {"feature_length": torch.tensor(length_batch), "text": text_batch}
216
+ if text_end_batch is not None:
217
+ x["feature_text_end"] = text_end_batch
218
+
219
+ # Non-streaming generate (following generate_ldf.py 125-139)
220
+ output = self.ldf_model.generate(x, num_denoise_steps=num_denoise_steps)
221
+ generated_batch = output["generated"]
222
+
223
+ # Decode with VAE and optionally convert to joints
224
+ decoded_results = []
225
+ joints_results = [] if output_joints else None
226
+
227
+ # Import motion processing module once if needed
228
+ if output_joints:
229
+ import importlib.util
230
+ import numpy as np
231
+ utils_spec = importlib.util.spec_from_file_location(
232
+ "motion_process",
233
+ os.path.join(self.model_dir, "ldf_utils", "motion_process.py")
234
+ )
235
+ motion_process_module = importlib.util.module_from_spec(utils_spec)
236
+ utils_spec.loader.exec_module(motion_process_module)
237
+
238
+ for i, generated in enumerate(generated_batch):
239
+ if generated is not None and torch.is_tensor(generated):
240
+ # Decode with VAE (following generate_ldf.py line 130)
241
+ decoded_g = self.vae.decode(generated[None, :])[0]
242
+
243
+ if output_joints:
244
+ # Convert to joints using StreamJointRecovery263 with smoothing
245
+ # Create a new recovery instance for each sample to maintain independent state
246
+ decoded_np = decoded_g.cpu().numpy()
247
+ recovery = motion_process_module.StreamJointRecovery263(
248
+ joints_num=22, smoothing_alpha=smoothing_alpha
249
+ )
250
+ joints = [recovery.process_frame(frame) for frame in decoded_np]
251
+ joints = np.array(joints)
252
+ joints_results.append(joints)
253
+ else:
254
+ decoded_results.append(decoded_g.cpu().numpy())
255
+ else:
256
+ if output_joints:
257
+ joints_results.append(None)
258
+ else:
259
+ decoded_results.append(None)
260
+
261
+ # Return results
262
+ if output_joints:
263
+ return joints_results[0] if is_single else joints_results
264
+ else:
265
+ return decoded_results[0] if is_single else decoded_results
266
+
267
+ def generate(self, *args, **kwargs):
268
+ """Alias for __call__ to match transformers API"""
269
+ return self.__call__(*args, **kwargs)
270
+
271
+
272
+ # For backwards compatibility
273
+ LDFPipeline = LDFModel
274
+
275
+
276
+ # Register with AutoModel
277
+ try:
278
+ from transformers import AutoModel, AutoConfig
279
+ AutoConfig.register("ldf_motion", LDFConfig)
280
+ AutoModel.register(LDFConfig, LDFModel)
281
+ except Exception:
282
+ pass
ldf.yaml ADDED
@@ -0,0 +1,44 @@
1
+ exp_name: ldf
2
+ seed: 1234
3
+ debug: false
4
+ train: false
5
+
6
+ save_dir: ./outputs
7
+ resume_ckpt: null
8
+ test_ckpt: "model.safetensors"
9
+ test_vae_ckpt: "vae.safetensors"
10
+
11
+ test_vae:
12
+ target: ldf_models.vae_wan_1d.VAEWanModel
13
+ ema_decay: 0.99
14
+ params:
15
+ input_dim: 263
16
+ z_dim: 4
17
+
18
+ test_setting:
19
+ render: false
20
+ simple: true
21
+ recover_dim: 263
22
+
23
+ val_repeat: 1
24
+
25
+ model:
26
+ target: ldf_models.diffusion_forcing_wan_tiny.DiffForcingWanModel
27
+ ema_decay: 0.99
28
+ params:
29
+ model_name: "google/umt5-base"
30
+ input_dim: 4
31
+ noise_steps: 10
32
+ hidden_dim: 256
33
+ ffn_dim: 1024
34
+ freq_dim: 64
35
+ num_heads: 8
36
+ num_layers: 8
37
+ time_embedding_scale: 1.0
38
+ chunk_size: 5
39
+ use_text_cond: True
40
+ text_len: 128
41
+ drop_out: 0.1
42
+ cfg_scale: 5.0
43
+ prediction_type: "vel"
44
+ causal: False
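+
+ # Loading sketch (illustrative; the project reads this file through ldf_utils.initialize.load_config).
+ # With OmegaConf, which the README lists as a dependency:
+ #   from omegaconf import OmegaConf
+ #   cfg = OmegaConf.load("ldf.yaml")
+ #   cfg.model.params.hidden_dim  # -> 256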
ldf_models/__init__.py ADDED
File without changes
ldf_models/diffusion_forcing_wan_tiny.py ADDED
@@ -0,0 +1,943 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers import AutoTokenizer, AutoModel
8
+
9
+ from .tools.wan_model import WanModel
10
+
11
+
12
+ class HFT5Encoder:
13
+ """Wrapper for HuggingFace T5 encoder, compatible with original T5EncoderModel interface"""
14
+ def __init__(self, text_len, dtype=torch.float32, device=torch.device("cpu"), model_name="google/umt5-base"):
15
+ self.text_len = text_len
16
+ self.dtype = dtype
17
+ self.device = device
18
+
19
+ print(f"Loading {model_name} from HuggingFace...")
20
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
21
+ self.model = AutoModel.from_pretrained(
22
+ model_name,
23
+ torch_dtype=dtype  # torch_dtype is the long-standing from_pretrained kwarg for weight dtype
24
+ ).encoder # Only use the encoder part
25
+ self.model.eval()
26
+ self.model.requires_grad_(False)
27
+ self.model.to(device)
28
+
29
+ def __call__(self, texts, device):
30
+ """Encode texts, returns list of tensors (one per text, with padding removed)"""
31
+ # Tokenize
32
+ inputs = self.tokenizer(
33
+ texts,
34
+ padding=True,
35
+ truncation=True,
36
+ max_length=self.text_len,
37
+ return_tensors="pt"
38
+ )
39
+ ids = inputs.input_ids.to(device)
40
+ mask = inputs.attention_mask.to(device)
41
+
42
+ # Encode (model should already be on device via external .model.to(device) call)
43
+ context = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
44
+
45
+ # Get sequence lengths (excluding padding)
46
+ seq_lens = mask.sum(dim=1).long()
47
+
48
+ # Return list of tensors with padding removed (same as original T5EncoderModel)
49
+ return [u[:v] for u, v in zip(context, seq_lens)]
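+
+ # Illustrative call (shapes only; token counts depend on the tokenizer):
+ #   HFT5Encoder(text_len=128)(["walk forward", "run"], device)
+ #   -> [Tensor(n_tokens_0, 768), Tensor(n_tokens_1, 768)]  # umt5-base hidden size is 768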
50
+
51
+
52
+ class DiffForcingWanModel(nn.Module):
53
+ def __init__(
54
+ self,
55
+ model_name="google/umt5-base", # HuggingFace model name
56
+ input_dim=256,
57
+ hidden_dim=1024,
58
+ ffn_dim=2048,
59
+ freq_dim=256,
60
+ num_heads=8,
61
+ num_layers=8,
62
+ time_embedding_scale=1.0,
63
+ chunk_size=5,
64
+ noise_steps=10,
65
+ use_text_cond=True,
66
+ text_len=512,
67
+ drop_out=0.1,
68
+ cfg_scale=5.0,
69
+ prediction_type="vel", # "vel", "x0", "noise"
70
+ causal=False,
71
+ ):
72
+ super().__init__()
73
+
74
+ self.input_dim = input_dim
75
+ self.hidden_dim = hidden_dim
76
+ self.ffn_dim = ffn_dim
77
+ self.freq_dim = freq_dim
78
+ self.num_heads = num_heads
79
+ self.num_layers = num_layers
80
+ self.time_embedding_scale = time_embedding_scale
81
+ self.chunk_size = chunk_size
82
+ self.noise_steps = noise_steps
83
+ self.use_text_cond = use_text_cond
84
+ self.drop_out = drop_out
85
+ self.cfg_scale = cfg_scale
86
+ self.prediction_type = prediction_type
87
+ self.causal = causal
88
+
89
+ self.text_dim = 768 # umt5-base hidden size
90
+ self.text_len = text_len
91
+ self.model_name = model_name
92
+
93
+ # Load model and tokenizer from HuggingFace
94
+ print(f"Loading {model_name} from HuggingFace...")
95
+ self.text_encoder = HFT5Encoder(
96
+ text_len=text_len,
97
+ dtype=torch.bfloat16,
98
+ device=torch.device("cpu"),
99
+ model_name=model_name,
100
+ )
101
+
102
+ # Text encoding cache
103
+ self.text_cache = {}
104
+ self.model = WanModel(
105
+ model_type="t2v",
106
+ patch_size=(1, 1, 1),
107
+ text_len=self.text_len,
108
+ in_dim=self.input_dim,
109
+ dim=self.hidden_dim,
110
+ ffn_dim=self.ffn_dim,
111
+ freq_dim=self.freq_dim,
112
+ text_dim=self.text_dim,
113
+ out_dim=self.input_dim,
114
+ num_heads=self.num_heads,
115
+ num_layers=self.num_layers,
116
+ window_size=(-1, -1),
117
+ qk_norm=True,
118
+ cross_attn_norm=True,
119
+ eps=1e-6,
120
+ causal=self.causal,
121
+ )
122
+ self.param_dtype = torch.float32
123
+
124
+ def encode_text_with_cache(self, text_list, device):
125
+ """Encode text using cache
126
+ Args:
127
+ text_list: List[str], list of texts
128
+ device: torch.device
129
+ Returns:
130
+ List[Tensor]: List of encoded text features
131
+ """
132
+ text_features = []
133
+ indices_to_encode = []
134
+ texts_to_encode = []
135
+
136
+ # Check cache
137
+ for i, text in enumerate(text_list):
138
+ if text in self.text_cache:
139
+ # Get from cache and move to correct device
140
+ cached_feature = self.text_cache[text].to(device)
141
+ text_features.append(cached_feature)
142
+ else:
143
+ # Need to encode
144
+ text_features.append(None)
145
+ indices_to_encode.append(i)
146
+ texts_to_encode.append(text)
147
+
148
+ # Batch encode uncached texts
149
+ if texts_to_encode:
150
+ self.text_encoder.model.to(device)
151
+ encoded = self.text_encoder(texts_to_encode, device)
152
+
153
+ # Store in cache and update results
154
+ for idx, text, feature in zip(indices_to_encode, texts_to_encode, encoded):
155
+ # Cache to CPU to save GPU memory
156
+ self.text_cache[text] = feature.cpu()
157
+ text_features[idx] = feature
158
+
159
+ return text_features
160
+
161
+ def preprocess(self, x):
162
+ # (bs, T, C) -> (bs, C, T, 1, 1)
163
+ x = x.permute(0, 2, 1)[:, :, :, None, None]
164
+ return x
165
+
166
+ def postprocess(self, x):
167
+ # (bs, C, T, 1, 1) -> (bs, T, C)
168
+ x = x.permute(0, 2, 1, 3, 4).contiguous().view(x.size(0), x.size(2), -1)
169
+ return x
170
+
171
+ def _get_noise_levels(self, device, seq_len, time_steps):
172
+ """Get noise levels"""
173
+ # noise_level[i] = clip(1 + i / chunk_size - time_steps, 0, 1)
174
+ noise_level = torch.clamp(
175
+ 1
176
+ + torch.arange(seq_len, device=device) / self.chunk_size
177
+ - time_steps.unsqueeze(1),
178
+ min=0.0,
179
+ max=1.0,
180
+ )
181
+ return noise_level
182
+
183
+ def add_noise(self, x, noise_level):
184
+ """Add noise
185
+ Args:
186
+ x: (B, T, D)
187
+ noise_level: (B, T)
188
+ """
189
+ noise = torch.randn_like(x)
190
+ # noise_level: (B, T) -> (B, T, 1)
191
+ noise_level = noise_level.unsqueeze(-1)
192
+ noisy_x = x * (1 - noise_level) + noise_level * noise
193
+ return noisy_x, noise
194
+
195
+ def forward(self, x):
196
+ feature = x["feature"] # (B, T, C)
197
+ feature_length = x["feature_length"] # (B,)
198
+ batch_size, seq_len, _ = feature.shape
199
+ device = feature.device
200
+
201
+ # Randomly use a time step
202
+ time_steps = []
203
+ for i in range(batch_size):
204
+ valid_len = feature_length[i].item()
205
+ # Random float from 0 to valid_len/chunk_size, not an integer
206
+ max_time = valid_len / self.chunk_size
207
+ # max_time = valid_len / self.chunk_size + 1
208
+ time_steps.append(torch.FloatTensor(1).uniform_(0, max_time).item())
209
+ time_steps = torch.tensor(time_steps, device=device) # (B,)
210
+ noise_level = self._get_noise_levels(device, seq_len, time_steps) # (B, T)
211
+
212
+ # # Debug: Print noise levels
213
+ # print("Time steps and corresponding noise levels:")
214
+ # for i in range(batch_size):
215
+ # t = time_steps[i].item()
216
+ # # Get noise level at each position
217
+ # start_idx = int(self.chunk_size * (t - 1))
218
+ # end_idx = int(self.chunk_size * t) + 2
219
+ # # Limit to valid range
220
+ # start_idx = max(0, start_idx)
221
+ # end_idx = min(seq_len, end_idx)
222
+ # print(time_steps[i])
223
+ # print(noise_level[i, start_idx:end_idx])
224
+
225
+ # Add noise to entire sequence
226
+ noisy_feature, noise = self.add_noise(feature, noise_level) # (B, T, D)
227
+
228
+ # Debug: Print noise addition information
229
+ # print("Added noise levels at chunk positions:")
230
+ # for i in range(batch_size):
231
+ # t = time_steps[i].item()
232
+ # start_idx = int(self.chunk_size * (t - 1))
233
+ # end_idx = int(self.chunk_size * t) + 2
234
+ # # Limit to valid range
235
+ # start_idx = max(0, start_idx)
236
+ # end_idx = min(seq_len, end_idx)
237
+ # test1 = (
238
+ # feature[i, start_idx:end_idx, :] - noisy_feature[i, start_idx:end_idx, :]
239
+ # )
240
+ # test2 = (
241
+ # noise[i, start_idx:end_idx, :] - noisy_feature[i, start_idx:end_idx, :]
242
+ # )
243
+ # # Compute length on last dimension
244
+ # print(test1.norm(dim=-1))
245
+ # print(test2.norm(dim=-1))
246
+
247
+ feature = self.preprocess(feature) # (B, C, T, 1, 1)
248
+ noisy_feature = self.preprocess(noisy_feature) # (B, C, T, 1, 1)
249
+ noise = self.preprocess(noise) # (B, C, T, 1, 1)
250
+
251
+ feature_ref = []
252
+ noise_ref = []
253
+ noisy_feature_input = []
254
+ for i in range(batch_size):
255
+ t = time_steps[i].item()
256
+ end_index = int(self.chunk_size * t) + 1
257
+ valid_len = feature_length[i].item()
258
+ end_index = min(valid_len, end_index)
259
+ feature_ref.append(feature[i, :, :end_index, ...])
260
+ noise_ref.append(noise[i, :, :end_index, ...])
261
+ noisy_feature_input.append(noisy_feature[i, :, :end_index, ...])
262
+
263
+ # Encode text condition (using cache)
264
+ if self.use_text_cond and "text" in x:
265
+ text_list = x["text"] # List[str] or List[List[str]]
266
+ if isinstance(text_list[0], list):
267
+ text_end_list = x["feature_text_end"]
268
+ all_text_context = []
269
+ for single_text_list, single_text_end_list in zip(
270
+ text_list, text_end_list
271
+ ):
272
+ if np.random.rand() > self.drop_out:
273
+ single_text_end_list = [0] + [
274
+ min(t, seq_len) for t in single_text_end_list
275
+ ]
276
+ else:
277
+ single_text_list = [""]
278
+ single_text_end_list = [0, seq_len]
279
+ single_text_length_list = [
280
+ t - b
281
+ for t, b in zip(
282
+ single_text_end_list[1:], single_text_end_list[:-1]
283
+ )
284
+ ]
285
+ single_text_context = self.encode_text_with_cache(
286
+ single_text_list, device
287
+ )
288
+ single_text_context = [
289
+ u.to(self.param_dtype) for u in single_text_context
290
+ ]
291
+ for u, duration in zip(
292
+ single_text_context, single_text_length_list
293
+ ):
294
+ all_text_context.extend([u for _ in range(duration)])
295
+ all_text_context.extend(
296
+ [
297
+ single_text_context[-1]
298
+ for _ in range(seq_len - single_text_end_list[-1])
299
+ ]
300
+ )
301
+ else:
302
+ all_text_context = [
303
+ (u if np.random.rand() > self.drop_out else "") for u in text_list
304
+ ]
305
+ all_text_context = self.encode_text_with_cache(all_text_context, device)
306
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
307
+ else:
308
+ all_text_context = [""] * batch_size
309
+ all_text_context = self.encode_text_with_cache(all_text_context, device)
310
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
311
+
312
+ # Through WanModel
313
+ predicted_result = self.model(
314
+ noisy_feature_input,
315
+ noise_level * self.time_embedding_scale,
316
+ all_text_context,
317
+ seq_len,
318
+ y=None,
319
+ ) # (B, C, T, 1, 1)
320
+
321
+ loss = 0.0
322
+ for b in range(batch_size):
323
+ if self.prediction_type == "vel":
324
+ vel = feature_ref[b] - noise_ref[b] # (C, input_length, 1, 1)
325
+ squared_error = (
326
+ predicted_result[b][:, -self.chunk_size :, ...]
327
+ - vel[:, -self.chunk_size :, ...]
328
+ ) ** 2
329
+ elif self.prediction_type == "x0":
330
+ squared_error = (
331
+ predicted_result[b][:, -self.chunk_size :, ...]
332
+ - feature_ref[b][:, -self.chunk_size :, ...]
333
+ ) ** 2
334
+ elif self.prediction_type == "noise":
335
+ squared_error = (
336
+ predicted_result[b][:, -self.chunk_size :, ...]
337
+ - noise_ref[b][:, -self.chunk_size :, ...]
338
+ ) ** 2
339
+ sample_loss = squared_error.sum().mean()
340
+ loss += sample_loss
341
+ loss = loss / batch_size
342
+
343
+ loss_dict = {"total": loss, "mse": loss}
344
+ return loss_dict
345
+
346
+ def generate(self, x, num_denoise_steps=None):
347
+ """
348
+ Generation - Diffusion Forcing inference
349
+ Uses triangular noise schedule, progressively generating from left to right
350
+
351
+ Generation process:
352
+ 1. Start from t=0, gradually increase t
353
+ 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
354
+ 3. After each denoising step, t increases slightly and continues
355
+ """
356
+ feature_length = x["feature_length"]
357
+ batch_size = len(feature_length)
358
+ seq_len = max(feature_length).item()
359
+
360
+ # # debug
361
+ # x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)]
362
+ # x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)]
363
+ # text = x["text"]
364
+ # text_end = x["feature_text_end"]
365
+ # print(text)
366
+ # print(text_end)
367
+ # print(batch_size, seq_len, self.chunk_size)
368
+
369
+ if num_denoise_steps is None:
370
+ num_denoise_steps = self.noise_steps
371
+ assert num_denoise_steps % self.chunk_size == 0
372
+
373
+ device = next(self.parameters()).device
374
+
375
+ # Initialize entire sequence as pure noise
376
+ generated = torch.randn(
377
+ batch_size, seq_len + self.chunk_size, self.input_dim, device=device
378
+ )
379
+ generated = self.preprocess(generated) # (B, C, T, 1, 1)
380
+
381
+ # Calculate total number of time steps needed
382
+ max_t = 1 + (seq_len - 1) / self.chunk_size
383
+
384
+ # Step size for each advancement
385
+ dt = 1 / num_denoise_steps
386
+ total_steps = int(max_t / dt)
387
+
388
+ # Encode text condition (using cache)
389
+ if self.use_text_cond and "text" in x:
390
+ text_list = x["text"] # List[str] or List[List[str]]
391
+ if isinstance(text_list[0], list):
392
+ generated_length = []
393
+ text_end_list = x["feature_text_end"]
394
+ full_text = []
395
+ all_text_context = []
396
+ for single_text_list, single_text_end_list in zip(
397
+ text_list, text_end_list
398
+ ):
399
+ single_text_end_list = [0] + [
400
+ min(t, seq_len) for t in single_text_end_list
401
+ ]
402
+ generated_length.append(single_text_end_list[-1])
403
+ single_text_length_list = [
404
+ t - b
405
+ for t, b in zip(
406
+ single_text_end_list[1:], single_text_end_list[:-1]
407
+ )
408
+ ]
409
+ full_text.append(
410
+ " ////////// ".join(
411
+ [
412
+ f"{u} //dur:{t}"
413
+ for u, t in zip(
414
+ single_text_list, single_text_length_list
415
+ )
416
+ ]
417
+ )
418
+ )
419
+ single_text_context = self.encode_text_with_cache(
420
+ single_text_list, device
421
+ )
422
+ single_text_context = [
423
+ u.to(self.param_dtype) for u in single_text_context
424
+ ]
425
+ for u, duration in zip(
426
+ single_text_context, single_text_length_list
427
+ ):
428
+ all_text_context.extend([u for _ in range(duration)])
429
+ all_text_context.extend(
430
+ [
431
+ single_text_context[-1]
432
+ for _ in range(
433
+ seq_len + self.chunk_size - single_text_end_list[-1]
434
+ )
435
+ ]
436
+ )
437
+ else:
438
+ generated_length = feature_length
439
+ full_text = text_list
440
+ all_text_context = self.encode_text_with_cache(text_list, device)
441
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
442
+ else:
443
+ generated_length = feature_length
444
+ full_text = [""] * batch_size
445
+ all_text_context = [""] * batch_size
446
+ all_text_context = self.encode_text_with_cache(all_text_context, device)
447
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
448
+
449
+ # Get empty text condition encoding (for CFG)
450
+ text_null_list = [""] * batch_size
451
+ text_null_context = self.encode_text_with_cache(text_null_list, device)
452
+ text_null_context = [u.to(self.param_dtype) for u in text_null_context]
453
+
454
+ # print(len(all_text_context), len(text_null_context))
455
+
456
+ # Progressively advance from t=0 to t=max_t
457
+ for step in range(total_steps):
458
+ # Current time step
459
+ t = step * dt
460
+ start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
461
+ end_index = int(self.chunk_size * t) + 1
462
+ time_steps = torch.full((batch_size,), t, device=device)
463
+
464
+ # Calculate current noise schedule
465
+ noise_level = self._get_noise_levels(
466
+ device, seq_len + self.chunk_size, time_steps
467
+ ) # (B, T)
468
+
469
+ # Predict noise through WanModel
470
+ noisy_input = []
471
+ for i in range(batch_size):
472
+ noisy_input.append(generated[i, :, :end_index, ...])
473
+
474
+ predicted_result = self.model(
475
+ noisy_input,
476
+ noise_level * self.time_embedding_scale,
477
+ all_text_context,
478
+ seq_len + self.chunk_size,
479
+ y=None,
480
+ ) # (B, C, T, 1, 1)
481
+
482
+ # Adjust using CFG
483
+ if self.cfg_scale != 1.0:
484
+ predicted_result_null = self.model(
485
+ noisy_input,
486
+ noise_level * self.time_embedding_scale,
487
+ text_null_context,
488
+ seq_len + self.chunk_size,
489
+ y=None,
490
+ ) # (B, C, T, 1, 1)
491
+ predicted_result = [
492
+ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
493
+ for pv, pvn in zip(predicted_result, predicted_result_null)
494
+ ]
495
+
496
+ for i in range(batch_size):
497
+ predicted_result_i = predicted_result[i] # (C, input_length, 1, 1)
498
+ if self.prediction_type == "vel":
499
+ predicted_vel = predicted_result_i[:, start_index:end_index, ...]
500
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
501
+ elif self.prediction_type == "x0":
502
+ predicted_vel = (
503
+ predicted_result_i[:, start_index:end_index, ...]
504
+ - generated[i, :, start_index:end_index, ...]
505
+ ) / (
506
+ noise_level[i, start_index:end_index]
507
+ .unsqueeze(0)
508
+ .unsqueeze(-1)
509
+ .unsqueeze(-1)
510
+ )
511
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
512
+ elif self.prediction_type == "noise":
513
+ predicted_vel = (
514
+ generated[i, :, start_index:end_index, ...]
515
+ - predicted_result_i[:, start_index:end_index, ...]
516
+ ) / (
517
+ 1
518
+ + dt
519
+ - noise_level[i, start_index:end_index]
520
+ .unsqueeze(0)
521
+ .unsqueeze(-1)
522
+ .unsqueeze(-1)
523
+ )
524
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
525
+
526
+ generated = self.postprocess(generated) # (B, T, C)
527
+ y_hat_out = []
528
+ for i in range(batch_size):
529
+ # cut off the padding
530
+ single_generated = generated[i, : generated_length[i], :]
531
+ y_hat_out.append(single_generated)
532
+ out = {}
533
+ out["generated"] = y_hat_out
534
+ out["text"] = full_text
535
+
536
+ return out
537
+
538
+ @torch.no_grad()
539
+ def stream_generate(self, x, num_denoise_steps=None):
540
+ """
541
+ Streaming generation - Diffusion Forcing inference
542
+ Uses triangular noise schedule, progressively generating from left to right
543
+
544
+ Generation process:
545
+ 1. Start from t=0, gradually increase t
546
+ 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
547
+ 3. After each denoising step, t increases slightly and continues
548
+ """
549
+ feature_length = x["feature_length"]
550
+ batch_size = len(feature_length)
551
+ seq_len = max(feature_length).item()
552
+
553
+ # # debug
554
+ # x["text"] = [["walk forward.", "sit down.", "stand up."] for _ in range(batch_size)]
555
+ # x["feature_text_end"] = [[1, 2, 3] for _ in range(batch_size)]
556
+ # text = x["text"]
557
+ # text_end = x["feature_text_end"]
558
+ # print(text)
559
+ # print(text_end)
560
+ # print(batch_size, seq_len, self.chunk_size)
561
+
562
+ if num_denoise_steps is None:
563
+ num_denoise_steps = self.noise_steps
564
+ assert num_denoise_steps % self.chunk_size == 0
565
+
566
+ device = next(self.parameters()).device
567
+
568
+ # Initialize entire sequence as pure noise
569
+ generated = torch.randn(
570
+ batch_size, seq_len + self.chunk_size, self.input_dim, device=device
571
+ )
572
+ generated = self.preprocess(generated) # (B, C, T, 1, 1)
573
+
574
+ # Calculate total number of time steps needed
575
+ max_t = 1 + (seq_len - 1) / self.chunk_size
576
+
577
+ # Step size for each advancement
578
+ dt = 1 / num_denoise_steps
579
+ total_steps = int(max_t / dt)
580
+
581
+ # Encode text condition (using cache)
582
+ if self.use_text_cond and "text" in x:
583
+ text_list = x["text"] # List[str] or List[List[str]]
584
+ if isinstance(text_list[0], list):
585
+ generated_length = []
586
+ text_end_list = x["feature_text_end"]
587
+ full_text = []
588
+ all_text_context = []
589
+ for single_text_list, single_text_end_list in zip(
590
+ text_list, text_end_list
591
+ ):
592
+ single_text_end_list = [0] + [
593
+ min(t, seq_len) for t in single_text_end_list
594
+ ]
595
+ generated_length.append(single_text_end_list[-1])
596
+ single_text_length_list = [
597
+ t - b
598
+ for t, b in zip(
599
+ single_text_end_list[1:], single_text_end_list[:-1]
600
+ )
601
+ ]
602
+ full_text.append(
603
+ " ////////// ".join(
604
+ [
605
+ f"{u} //dur:{t}"
606
+ for u, t in zip(
607
+ single_text_list, single_text_length_list
608
+ )
609
+ ]
610
+ )
611
+ )
612
+ single_text_context = self.encode_text_with_cache(
613
+ single_text_list, device
614
+ )
615
+ single_text_context = [
616
+ u.to(self.param_dtype) for u in single_text_context
617
+ ]
618
+ for u, duration in zip(
619
+ single_text_context, single_text_length_list
620
+ ):
621
+ all_text_context.extend([u for _ in range(duration)])
622
+ all_text_context.extend(
623
+ [
624
+ single_text_context[-1]
625
+ for _ in range(
626
+ seq_len + self.chunk_size - single_text_end_list[-1]
627
+ )
628
+ ]
629
+ )
630
+ else:
631
+ generated_length = feature_length
632
+ full_text = text_list
633
+ all_text_context = self.encode_text_with_cache(text_list, device)
634
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
635
+ else:
636
+ generated_length = feature_length
637
+ full_text = [""] * batch_size
638
+ all_text_context = [""] * batch_size
639
+ all_text_context = self.encode_text_with_cache(all_text_context, device)
640
+ all_text_context = [u.to(self.param_dtype) for u in all_text_context]
641
+
642
+ # Get empty text condition encoding (for CFG)
643
+ text_null_list = [""] * batch_size
644
+ text_null_context = self.encode_text_with_cache(text_null_list, device)
645
+ text_null_context = [u.to(self.param_dtype) for u in text_null_context]
646
+
647
+ # print(len(all_text_context), len(text_null_context))
648
+
649
+ commit_index = 0
650
+ # Progressively advance from t=0 to t=max_t
651
+ for step in range(total_steps):
652
+ # Current time step
653
+ t = step * dt
654
+ start_index = max(0, int(self.chunk_size * (t - 1)) + 1)
655
+ end_index = int(self.chunk_size * t) + 1
656
+ time_steps = torch.full((batch_size,), t, device=device)
657
+
658
+ # Calculate current noise schedule
659
+ noise_level = self._get_noise_levels(
660
+ device, seq_len + self.chunk_size, time_steps
661
+ ) # (B, T)
662
+
663
+ # Predict noise through WanModel
664
+ noisy_input = []
665
+ for i in range(batch_size):
666
+ noisy_input.append(generated[i, :, :end_index, ...])
667
+
668
+ predicted_result = self.model(
669
+ noisy_input,
670
+ noise_level * self.time_embedding_scale,
671
+ all_text_context,
672
+ seq_len + self.chunk_size,
673
+ y=None,
674
+ ) # (B, C, T, 1, 1)
675
+
676
+ # Adjust using CFG
677
+ if self.cfg_scale != 1.0:
678
+ predicted_result_null = self.model(
679
+ noisy_input,
680
+ noise_level * self.time_embedding_scale,
681
+ text_null_context,
682
+ seq_len + self.chunk_size,
683
+ y=None,
684
+ ) # (B, C, T, 1, 1)
685
+ predicted_result = [
686
+ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
687
+ for pv, pvn in zip(predicted_result, predicted_result_null)
688
+ ]
689
+
690
+ for i in range(batch_size):
691
+ predicted_result_i = predicted_result[i] # (C, input_length, 1, 1)
692
+ if self.prediction_type == "vel":
693
+ predicted_vel = predicted_result_i[:, start_index:end_index, ...]
694
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
695
+ elif self.prediction_type == "x0":
696
+ predicted_vel = (
697
+ predicted_result_i[:, start_index:end_index, ...]
698
+ - generated[i, :, start_index:end_index, ...]
699
+ ) / (
700
+ noise_level[i, start_index:end_index]
701
+ .unsqueeze(0)
702
+ .unsqueeze(-1)
703
+ .unsqueeze(-1)
704
+ )
705
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
706
+ elif self.prediction_type == "noise":
707
+ predicted_vel = (
708
+ generated[i, :, start_index:end_index, ...]
709
+ - predicted_result_i[:, start_index:end_index, ...]
710
+ ) / (
711
+ 1
712
+ + dt
713
+ - noise_level[i, start_index:end_index]
714
+ .unsqueeze(0)
715
+ .unsqueeze(-1)
716
+ .unsqueeze(-1)
717
+ )
718
+ generated[i, :, start_index:end_index, ...] += predicted_vel * dt
719
+
720
+ if commit_index < start_index:
721
+ output = generated[:, :, commit_index:start_index, ...]
722
+ output = self.postprocess(output) # (B, T, C)
723
+ y_hat_out = []
724
+ for i in range(batch_size):
725
+ if commit_index < generated_length[i]:
726
+ y_hat_out.append(
727
+ output[i, : generated_length[i] - commit_index, ...]
728
+ )
729
+ else:
730
+ y_hat_out.append(None)
731
+
732
+ out = {}
733
+ out["generated"] = y_hat_out
734
+ yield out
735
+ commit_index = start_index
736
+
737
+ output = generated[:, :, commit_index:, ...]
738
+ output = self.postprocess(output) # (B, T_remain, C)
739
+ y_hat_out = []
740
+ for i in range(batch_size):
741
+ if commit_index < generated_length[i]:
742
+ y_hat_out.append(output[i, : generated_length[i] - commit_index, ...])
743
+ else:
744
+ y_hat_out.append(None)
745
+ out = {}
746
+ out["generated"] = y_hat_out
747
+ yield out
748
+
749
+ def init_generated(self, seq_len, batch_size=1, num_denoise_steps=None):
750
+ self.seq_len = seq_len
751
+ self.batch_size = batch_size
752
+ if num_denoise_steps is None:
753
+ self.num_denoise_steps = self.noise_steps
754
+ else:
755
+ self.num_denoise_steps = num_denoise_steps
756
+ assert self.num_denoise_steps % self.chunk_size == 0
757
+ self.dt = 1 / self.num_denoise_steps
758
+ self.current_step = 0
759
+ self.text_condition_list = [[] for _ in range(self.batch_size)]
760
+ self.generated = torch.randn(
761
+ self.batch_size, self.seq_len * 2 + self.chunk_size, self.input_dim
762
+ )
763
+ self.generated = self.preprocess(self.generated) # (B, C, T, 1, 1)
764
+ self.commit_index = 0
765
+
766
+ @torch.no_grad()
767
+ def stream_generate_step(self, x, first_chunk=True):
768
+ """
769
+ Streaming generation step - Diffusion Forcing inference
770
+ Uses triangular noise schedule, progressively generating from left to right
771
+
772
+ Generation process:
773
+ 1. Start from t=0, gradually increase t
774
+ 2. Each t corresponds to a noise schedule: clean on left, noisy on right, gradient in middle
775
+ 3. After each denoising step, t increases slightly and continues
776
+ """
777
+
778
+ device = next(self.parameters()).device
779
+ if first_chunk:
780
+ self.generated = self.generated.to(device)
781
+
782
+ # Encode text condition (using cache)
783
+ if self.use_text_cond and "text" in x:
784
+ text_list = x["text"] # List[str]
785
+ new_text_context = self.encode_text_with_cache(text_list, device)
786
+ new_text_context = [u.to(self.param_dtype) for u in new_text_context]
787
+ else:
788
+ new_text_context = [""] * self.batch_size
789
+ new_text_context = self.encode_text_with_cache(new_text_context, device)
790
+ new_text_context = [u.to(self.param_dtype) for u in new_text_context]
791
+
792
+ # Get empty text condition encoding (for CFG)
793
+ text_null_list = [""] * self.batch_size
794
+ text_null_context = self.encode_text_with_cache(text_null_list, device)
795
+ text_null_context = [u.to(self.param_dtype) for u in text_null_context]
796
+
797
+ for i in range(self.batch_size):
798
+ if first_chunk:
799
+ self.text_condition_list[i].extend(
800
+ [new_text_context[i]] * self.chunk_size
801
+ )
802
+ else:
803
+ self.text_condition_list[i].extend([new_text_context[i]])
804
+
805
+ end_step = (
806
+ (self.commit_index + self.chunk_size)
807
+ * self.num_denoise_steps
808
+ / self.chunk_size
809
+ )
810
+ while self.current_step < end_step:
811
+ current_time = self.current_step * self.dt
812
+ start_index = max(0, int(self.chunk_size * (current_time - 1)) + 1)
813
+ end_index = int(self.chunk_size * current_time) + 1
814
+ time_steps = torch.full((self.batch_size,), current_time, device=device)
815
+
816
+ noise_level = self._get_noise_levels(device, end_index, time_steps)[
817
+ :, -self.seq_len :
818
+ ] # (B, T)
819
+
820
+ # Predict noise through WanModel
821
+ noisy_input = []
822
+ for i in range(self.batch_size):
823
+ noisy_input.append(
824
+ self.generated[i, :, :end_index, ...][:, -self.seq_len :]
825
+ ) # (C, T, 1, 1)
826
+
827
+ text_condition = []
828
+ for i in range(self.batch_size):
829
+ text_condition.extend(
830
+ self.text_condition_list[i][:end_index][-self.seq_len :]
831
+ ) # (T, D, 4096)
832
+
833
+ # print("////////////////////")
834
+ # print("current step: ", self.current_step)
835
+ # print("chunk size: ", self.chunk_size)
836
+ # print("start_index: ", start_index)
837
+ # print("end_index: ", end_index)
838
+ # print("noisy_input shape: ", noisy_input[0].shape)
839
+ # print("noise_level: ", noise_level[0, start_index:end_index])
840
+ # print("text_condition shape: ", len(text_condition))
841
+ # print("commit_index: ", self.commit_index)
842
+ # print("////////////////////")
843
+
844
+ predicted_result = self.model(
845
+ noisy_input,
846
+ noise_level * self.time_embedding_scale,
847
+ text_condition,
848
+ min(end_index, self.seq_len),
849
+ y=None,
850
+ ) # (B, C, T, 1, 1)
851
+
852
+ # Adjust using CFG
853
+ if self.cfg_scale != 1.0:
854
+ predicted_result_null = self.model(
855
+ noisy_input,
856
+ noise_level * self.time_embedding_scale,
857
+ text_null_context,
858
+ min(end_index, self.seq_len),
859
+ y=None,
860
+ ) # (B, C, T, 1, 1)
861
+ predicted_result = [
862
+ self.cfg_scale * pv - (self.cfg_scale - 1) * pvn
863
+ for pv, pvn in zip(predicted_result, predicted_result_null)
864
+ ]
865
+
866
+ for i in range(self.batch_size):
867
+ predicted_result_i = predicted_result[i] # (C, input_length, 1, 1)
868
+ if end_index > self.seq_len:
869
+ predicted_result_i = torch.cat(
870
+ [
871
+ torch.zeros(
872
+ predicted_result_i.shape[0],
873
+ end_index - self.seq_len,
874
+ predicted_result_i.shape[2],
875
+ predicted_result_i.shape[3],
876
+ device=device,
877
+ ),
878
+ predicted_result_i,
879
+ ],
880
+ dim=1,
881
+ )
882
+ if self.prediction_type == "vel":
883
+ predicted_vel = predicted_result_i[:, start_index:end_index, ...]
884
+ self.generated[i, :, start_index:end_index, ...] += (
885
+ predicted_vel * self.dt
886
+ )
887
+ elif self.prediction_type == "x0":
888
+ predicted_vel = (
889
+ predicted_result_i[:, start_index:end_index, ...]
890
+ - self.generated[i, :, start_index:end_index, ...]
891
+ ) / (
892
+ noise_level[i, start_index:end_index]
893
+ .unsqueeze(0)
894
+ .unsqueeze(-1)
895
+ .unsqueeze(-1)
896
+ )
897
+ self.generated[i, :, start_index:end_index, ...] += (
898
+ predicted_vel * self.dt
899
+ )
900
+ elif self.prediction_type == "noise":
901
+ predicted_vel = (
902
+ self.generated[i, :, start_index:end_index, ...]
903
+ - predicted_result_i[:, start_index:end_index, ...]
904
+ ) / (
905
+ 1
906
+ + self.dt
907
+ - noise_level[i, start_index:end_index]
908
+ .unsqueeze(0)
909
+ .unsqueeze(-1)
910
+ .unsqueeze(-1)
911
+ )
912
+ self.generated[i, :, start_index:end_index, ...] += (
913
+ predicted_vel * self.dt
914
+ )
915
+ self.current_step += 1
916
+ output = self.generated[:, :, self.commit_index : self.commit_index + 1, ...]
917
+ output = self.postprocess(output) # (B, 1, C)
918
+ out = {}
919
+ out["generated"] = output
920
+ self.commit_index += 1
921
+
922
+ if self.commit_index == self.seq_len * 2:
923
+ self.generated = torch.cat(
924
+ [
925
+ self.generated[:, :, self.seq_len :, ...],
926
+ torch.randn(
927
+ self.batch_size,
928
+ self.input_dim,
929
+ self.seq_len,
930
+ 1,
931
+ 1,
932
+ device=device,
933
+ ),
934
+ ],
935
+ dim=2,
936
+ )
937
+ self.current_step -= self.seq_len * self.num_denoise_steps / self.chunk_size
938
+ self.commit_index -= self.seq_len
939
+ for i in range(self.batch_size):
940
+ self.text_condition_list[i] = self.text_condition_list[i][
941
+ self.seq_len :
942
+ ]
943
+ return out
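The streaming entry points above (`init_generated` followed by repeated `stream_generate_step` calls) are driven one committed frame per call. Below is a minimal usage sketch, assuming a loaded FloodDiffusion model instance named `model` that exposes exactly these two methods; the prompt, frame count, and window length are illustrative.

```python
# Minimal driver sketch for the streaming API above (names are illustrative).
import torch

@torch.no_grad()
def stream_motion(model, prompt, num_frames, seq_len=64):
    # Allocate the rolling noise buffer and reset the denoising clock.
    model.init_generated(seq_len=seq_len, batch_size=1)
    frames = []
    for i in range(num_frames):
        out = model.stream_generate_step(
            {"text": [prompt]},      # per-step text condition (may change mid-stream)
            first_chunk=(i == 0),    # the first call seeds a whole chunk of conditions
        )
        frames.append(out["generated"])  # (B, 1, C) committed frame
    return torch.cat(frames, dim=1)      # (B, num_frames, C)
```

Because the text condition is re-read on every call, the prompt can be swapped mid-stream to steer the remainder of the motion.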
ldf_models/tools/attention.py ADDED
@@ -0,0 +1,188 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+
7
+ FLASH_ATTN_3_AVAILABLE = True
8
+ except ModuleNotFoundError:
9
+ FLASH_ATTN_3_AVAILABLE = False
10
+
11
+ try:
12
+ import flash_attn
13
+
14
+ FLASH_ATTN_2_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_2_AVAILABLE = False
17
+
18
+ import warnings
19
+
20
+ __all__ = [
21
+ "flash_attention",
22
+ "attention",
23
+ ]
24
+
25
+
26
+ def flash_attention(
27
+ q,
28
+ k,
29
+ v,
30
+ q_lens=None,
31
+ k_lens=None,
32
+ dropout_p=0.0,
33
+ softmax_scale=None,
34
+ q_scale=None,
35
+ causal=False,
36
+ window_size=(-1, -1),
37
+ deterministic=False,
38
+ dtype=torch.bfloat16,
39
+ version=None,
40
+ ):
41
+ """
42
+ q: [B, Lq, Nq, C1].
43
+ k: [B, Lk, Nk, C1].
44
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
45
+ q_lens: [B].
46
+ k_lens: [B].
47
+ dropout_p: float. Dropout probability.
48
+ softmax_scale: float. The scaling of QK^T before applying softmax.
49
+ causal: bool. Whether to apply causal attention mask.
50
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
51
+ deterministic: bool. If True, slightly slower and uses more memory.
52
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
53
+ """
54
+ half_dtypes = (torch.float16, torch.bfloat16)
55
+ assert dtype in half_dtypes
56
+ assert q.device.type == "cuda" and q.size(-1) <= 256
57
+
58
+ # params
59
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
60
+
61
+ def half(x):
62
+ return x if x.dtype in half_dtypes else x.to(dtype)
63
+
64
+ # preprocess query
65
+ if q_lens is None:
66
+ q = half(q.flatten(0, 1))
67
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(
68
+ device=q.device, non_blocking=True
69
+ )
70
+ else:
71
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
72
+
73
+ # preprocess key, value
74
+ if k_lens is None:
75
+ k = half(k.flatten(0, 1))
76
+ v = half(v.flatten(0, 1))
77
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(
78
+ device=k.device, non_blocking=True
79
+ )
80
+ else:
81
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
82
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
83
+
84
+ q = q.to(v.dtype)
85
+ k = k.to(v.dtype)
86
+
87
+ if q_scale is not None:
88
+ q = q * q_scale
89
+
90
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
91
+ warnings.warn(
92
+ "Flash attention 3 is not available, use flash attention 2 instead."
93
+ )
94
+
95
+ # apply attention
96
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
97
+ # Note: dropout_p, window_size are not supported in FA3 now.
98
+ x = flash_attn_interface.flash_attn_varlen_func(
99
+ q=q,
100
+ k=k,
101
+ v=v,
102
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
103
+ .cumsum(0, dtype=torch.int32)
104
+ .to(q.device, non_blocking=True),
105
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
106
+ .cumsum(0, dtype=torch.int32)
107
+ .to(q.device, non_blocking=True),
108
+ seqused_q=None,
109
+ seqused_k=None,
110
+ max_seqlen_q=lq,
111
+ max_seqlen_k=lk,
112
+ softmax_scale=softmax_scale,
113
+ causal=causal,
114
+ deterministic=deterministic,
115
+ )[0].unflatten(0, (b, lq))
116
+ else:
117
+ assert FLASH_ATTN_2_AVAILABLE
118
+ x = flash_attn.flash_attn_varlen_func(
119
+ q=q,
120
+ k=k,
121
+ v=v,
122
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
123
+ .cumsum(0, dtype=torch.int32)
124
+ .to(q.device, non_blocking=True),
125
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
126
+ .cumsum(0, dtype=torch.int32)
127
+ .to(q.device, non_blocking=True),
128
+ max_seqlen_q=lq,
129
+ max_seqlen_k=lk,
130
+ dropout_p=dropout_p,
131
+ softmax_scale=softmax_scale,
132
+ causal=causal,
133
+ window_size=window_size,
134
+ deterministic=deterministic,
135
+ ).unflatten(0, (b, lq))
136
+
137
+ # output
138
+ return x.type(out_dtype)
139
+
140
+
141
+ def attention(
142
+ q,
143
+ k,
144
+ v,
145
+ q_lens=None,
146
+ k_lens=None,
147
+ dropout_p=0.0,
148
+ softmax_scale=None,
149
+ q_scale=None,
150
+ causal=False,
151
+ window_size=(-1, -1),
152
+ deterministic=False,
153
+ dtype=torch.bfloat16,
154
+ fa_version=None,
155
+ ):
156
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
157
+ return flash_attention(
158
+ q=q,
159
+ k=k,
160
+ v=v,
161
+ q_lens=q_lens,
162
+ k_lens=k_lens,
163
+ dropout_p=dropout_p,
164
+ softmax_scale=softmax_scale,
165
+ q_scale=q_scale,
166
+ causal=causal,
167
+ window_size=window_size,
168
+ deterministic=deterministic,
169
+ dtype=dtype,
170
+ version=fa_version,
171
+ )
172
+ else:
173
+ if q_lens is not None or k_lens is not None:
174
+ warnings.warn(
175
+ "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
176
+ )
177
+ attn_mask = None
178
+
179
+ q = q.transpose(1, 2).to(dtype)
180
+ k = k.transpose(1, 2).to(dtype)
181
+ v = v.transpose(1, 2).to(dtype)
182
+
183
+ out = torch.nn.functional.scaled_dot_product_attention(
184
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
185
+ )
186
+
187
+ out = out.transpose(1, 2).contiguous()
188
+ return out
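As a sanity check on the wrapper above, here is a small shape test under stated assumptions: the module is importable as `ldf_models.tools.attention`, and if flash-attn is installed a CUDA device is required, otherwise the call falls through to PyTorch's `scaled_dot_product_attention`.

```python
# Shape check for the attention() wrapper (assumed import path).
import torch
from ldf_models.tools.attention import attention

device = "cuda" if torch.cuda.is_available() else "cpu"
b, lq, lk, heads, head_dim = 2, 16, 24, 8, 64
q = torch.randn(b, lq, heads, head_dim, device=device, dtype=torch.bfloat16)
k = torch.randn(b, lk, heads, head_dim, device=device, dtype=torch.bfloat16)
v = torch.randn(b, lk, heads, head_dim, device=device, dtype=torch.bfloat16)

out = attention(q, k, v, causal=False)
print(out.shape)  # torch.Size([2, 16, 8, 64]) - same layout as q
```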
ldf_models/tools/t5.py ADDED
@@ -0,0 +1,564 @@
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .tokenizers import HuggingfaceTokenizer
11
+
12
+ __all__ = [
13
+ "T5Model",
14
+ "T5Encoder",
15
+ "T5Decoder",
16
+ "T5EncoderModel",
17
+ ]
18
+
19
+
20
+ def fp16_clamp(x):
21
+ if x.dtype == torch.float16 and torch.isinf(x).any():
22
+ clamp = torch.finfo(x.dtype).max - 1000
23
+ x = torch.clamp(x, min=-clamp, max=clamp)
24
+ return x
25
+
26
+
27
+ def init_weights(m):
28
+ if isinstance(m, T5LayerNorm):
29
+ nn.init.ones_(m.weight)
30
+ elif isinstance(m, T5Model):
31
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
32
+ elif isinstance(m, T5FeedForward):
33
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
+ elif isinstance(m, T5Attention):
37
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
38
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
41
+ elif isinstance(m, T5RelativeEmbedding):
42
+ nn.init.normal_(
43
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5
44
+ )
45
+
46
+
47
+ class GELU(nn.Module):
48
+ def forward(self, x):
49
+ return (
50
+ 0.5
51
+ * x
52
+ * (
53
+ 1.0
54
+ + torch.tanh(
55
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
56
+ )
57
+ )
58
+ )
59
+
60
+
61
+ class T5LayerNorm(nn.Module):
62
+ def __init__(self, dim, eps=1e-6):
63
+ super(T5LayerNorm, self).__init__()
64
+ self.dim = dim
65
+ self.eps = eps
66
+ self.weight = nn.Parameter(torch.ones(dim))
67
+
68
+ def forward(self, x):
69
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
70
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
71
+ x = x.type_as(self.weight)
72
+ return self.weight * x
73
+
74
+
75
+ class T5Attention(nn.Module):
76
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
77
+ assert dim_attn % num_heads == 0
78
+ super(T5Attention, self).__init__()
79
+ self.dim = dim
80
+ self.dim_attn = dim_attn
81
+ self.num_heads = num_heads
82
+ self.head_dim = dim_attn // num_heads
83
+
84
+ # layers
85
+ self.q = nn.Linear(dim, dim_attn, bias=False)
86
+ self.k = nn.Linear(dim, dim_attn, bias=False)
87
+ self.v = nn.Linear(dim, dim_attn, bias=False)
88
+ self.o = nn.Linear(dim_attn, dim, bias=False)
89
+ self.dropout = nn.Dropout(dropout)
90
+
91
+ def forward(self, x, context=None, mask=None, pos_bias=None):
92
+ """
93
+ x: [B, L1, C].
94
+ context: [B, L2, C] or None.
95
+ mask: [B, L2] or [B, L1, L2] or None.
96
+ """
97
+ # check inputs
98
+ context = x if context is None else context
99
+ b, n, c = x.size(0), self.num_heads, self.head_dim
100
+
101
+ # compute query, key, value
102
+ q = self.q(x).view(b, -1, n, c)
103
+ k = self.k(context).view(b, -1, n, c)
104
+ v = self.v(context).view(b, -1, n, c)
105
+
106
+ # attention bias
107
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
108
+ if pos_bias is not None:
109
+ attn_bias += pos_bias
110
+ if mask is not None:
111
+ assert mask.ndim in [2, 3]
112
+ mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
113
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
114
+
115
+ # compute attention (T5 does not use scaling)
116
+ attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
117
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
118
+ x = torch.einsum("bnij,bjnc->binc", attn, v)
119
+
120
+ # output
121
+ x = x.reshape(b, -1, n * c)
122
+ x = self.o(x)
123
+ x = self.dropout(x)
124
+ return x
125
+
126
+
127
+ class T5FeedForward(nn.Module):
128
+ def __init__(self, dim, dim_ffn, dropout=0.1):
129
+ super(T5FeedForward, self).__init__()
130
+ self.dim = dim
131
+ self.dim_ffn = dim_ffn
132
+
133
+ # layers
134
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
135
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
136
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
137
+ self.dropout = nn.Dropout(dropout)
138
+
139
+ def forward(self, x):
140
+ x = self.fc1(x) * self.gate(x)
141
+ x = self.dropout(x)
142
+ x = self.fc2(x)
143
+ x = self.dropout(x)
144
+ return x
145
+
146
+
147
+ class T5SelfAttention(nn.Module):
148
+ def __init__(
149
+ self,
150
+ dim,
151
+ dim_attn,
152
+ dim_ffn,
153
+ num_heads,
154
+ num_buckets,
155
+ shared_pos=True,
156
+ dropout=0.1,
157
+ ):
158
+ super(T5SelfAttention, self).__init__()
159
+ self.dim = dim
160
+ self.dim_attn = dim_attn
161
+ self.dim_ffn = dim_ffn
162
+ self.num_heads = num_heads
163
+ self.num_buckets = num_buckets
164
+ self.shared_pos = shared_pos
165
+
166
+ # layers
167
+ self.norm1 = T5LayerNorm(dim)
168
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
169
+ self.norm2 = T5LayerNorm(dim)
170
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
171
+ self.pos_embedding = (
172
+ None
173
+ if shared_pos
174
+ else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
175
+ )
176
+
177
+ def forward(self, x, mask=None, pos_bias=None):
178
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
179
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
180
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
181
+ return x
182
+
183
+
184
+ class T5CrossAttention(nn.Module):
185
+ def __init__(
186
+ self,
187
+ dim,
188
+ dim_attn,
189
+ dim_ffn,
190
+ num_heads,
191
+ num_buckets,
192
+ shared_pos=True,
193
+ dropout=0.1,
194
+ ):
195
+ super(T5CrossAttention, self).__init__()
196
+ self.dim = dim
197
+ self.dim_attn = dim_attn
198
+ self.dim_ffn = dim_ffn
199
+ self.num_heads = num_heads
200
+ self.num_buckets = num_buckets
201
+ self.shared_pos = shared_pos
202
+
203
+ # layers
204
+ self.norm1 = T5LayerNorm(dim)
205
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
206
+ self.norm2 = T5LayerNorm(dim)
207
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
208
+ self.norm3 = T5LayerNorm(dim)
209
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
210
+ self.pos_embedding = (
211
+ None
212
+ if shared_pos
213
+ else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
214
+ )
215
+
216
+ def forward(
217
+ self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None
218
+ ):
219
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
220
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
221
+ x = fp16_clamp(
222
+ x
223
+ + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)
224
+ )
225
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
226
+ return x
227
+
228
+
229
+ class T5RelativeEmbedding(nn.Module):
230
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
231
+ super(T5RelativeEmbedding, self).__init__()
232
+ self.num_buckets = num_buckets
233
+ self.num_heads = num_heads
234
+ self.bidirectional = bidirectional
235
+ self.max_dist = max_dist
236
+
237
+ # layers
238
+ self.embedding = nn.Embedding(num_buckets, num_heads)
239
+
240
+ def forward(self, lq, lk):
241
+ device = self.embedding.weight.device
242
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
243
+ # torch.arange(lq).unsqueeze(1).to(device)
244
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(
245
+ lq, device=device
246
+ ).unsqueeze(1)
247
+ rel_pos = self._relative_position_bucket(rel_pos)
248
+ rel_pos_embeds = self.embedding(rel_pos)
249
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
250
+ return rel_pos_embeds.contiguous()
251
+
252
+ def _relative_position_bucket(self, rel_pos):
253
+ # preprocess
254
+ if self.bidirectional:
255
+ num_buckets = self.num_buckets // 2
256
+ rel_buckets = (rel_pos > 0).long() * num_buckets
257
+ rel_pos = torch.abs(rel_pos)
258
+ else:
259
+ num_buckets = self.num_buckets
260
+ rel_buckets = 0
261
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
262
+
263
+ # embeddings for small and large positions
264
+ max_exact = num_buckets // 2
265
+ rel_pos_large = (
266
+ max_exact
267
+ + (
268
+ torch.log(rel_pos.float() / max_exact)
269
+ / math.log(self.max_dist / max_exact)
270
+ * (num_buckets - max_exact)
271
+ ).long()
272
+ )
273
+ rel_pos_large = torch.min(
274
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)
275
+ )
276
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
277
+ return rel_buckets
278
+
279
+
280
+ class T5Encoder(nn.Module):
281
+ def __init__(
282
+ self,
283
+ vocab,
284
+ dim,
285
+ dim_attn,
286
+ dim_ffn,
287
+ num_heads,
288
+ num_layers,
289
+ num_buckets,
290
+ shared_pos=True,
291
+ dropout=0.1,
292
+ ):
293
+ super(T5Encoder, self).__init__()
294
+ self.dim = dim
295
+ self.dim_attn = dim_attn
296
+ self.dim_ffn = dim_ffn
297
+ self.num_heads = num_heads
298
+ self.num_layers = num_layers
299
+ self.num_buckets = num_buckets
300
+ self.shared_pos = shared_pos
301
+
302
+ # layers
303
+ self.token_embedding = (
304
+ vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
305
+ )
306
+ self.pos_embedding = (
307
+ T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
308
+ if shared_pos
309
+ else None
310
+ )
311
+ self.dropout = nn.Dropout(dropout)
312
+ self.blocks = nn.ModuleList(
313
+ [
314
+ T5SelfAttention(
315
+ dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
316
+ )
317
+ for _ in range(num_layers)
318
+ ]
319
+ )
320
+ self.norm = T5LayerNorm(dim)
321
+
322
+ # initialize weights
323
+ self.apply(init_weights)
324
+
325
+ def forward(self, ids, mask=None):
326
+ x = self.token_embedding(ids)
327
+ x = self.dropout(x)
328
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
329
+ for block in self.blocks:
330
+ x = block(x, mask, pos_bias=e)
331
+ x = self.norm(x)
332
+ x = self.dropout(x)
333
+ return x
334
+
335
+
336
+ class T5Decoder(nn.Module):
337
+ def __init__(
338
+ self,
339
+ vocab,
340
+ dim,
341
+ dim_attn,
342
+ dim_ffn,
343
+ num_heads,
344
+ num_layers,
345
+ num_buckets,
346
+ shared_pos=True,
347
+ dropout=0.1,
348
+ ):
349
+ super(T5Decoder, self).__init__()
350
+ self.dim = dim
351
+ self.dim_attn = dim_attn
352
+ self.dim_ffn = dim_ffn
353
+ self.num_heads = num_heads
354
+ self.num_layers = num_layers
355
+ self.num_buckets = num_buckets
356
+ self.shared_pos = shared_pos
357
+
358
+ # layers
359
+ self.token_embedding = (
360
+ vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
361
+ )
362
+ self.pos_embedding = (
363
+ T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
364
+ if shared_pos
365
+ else None
366
+ )
367
+ self.dropout = nn.Dropout(dropout)
368
+ self.blocks = nn.ModuleList(
369
+ [
370
+ T5CrossAttention(
371
+ dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout
372
+ )
373
+ for _ in range(num_layers)
374
+ ]
375
+ )
376
+ self.norm = T5LayerNorm(dim)
377
+
378
+ # initialize weights
379
+ self.apply(init_weights)
380
+
381
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
382
+ b, s = ids.size()
383
+
384
+ # causal mask
385
+ if mask is None:
386
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
387
+ elif mask.ndim == 2:
388
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
389
+
390
+ # layers
391
+ x = self.token_embedding(ids)
392
+ x = self.dropout(x)
393
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
394
+ for block in self.blocks:
395
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
396
+ x = self.norm(x)
397
+ x = self.dropout(x)
398
+ return x
399
+
400
+
401
+ class T5Model(nn.Module):
402
+ def __init__(
403
+ self,
404
+ vocab_size,
405
+ dim,
406
+ dim_attn,
407
+ dim_ffn,
408
+ num_heads,
409
+ encoder_layers,
410
+ decoder_layers,
411
+ num_buckets,
412
+ shared_pos=True,
413
+ dropout=0.1,
414
+ ):
415
+ super(T5Model, self).__init__()
416
+ self.vocab_size = vocab_size
417
+ self.dim = dim
418
+ self.dim_attn = dim_attn
419
+ self.dim_ffn = dim_ffn
420
+ self.num_heads = num_heads
421
+ self.encoder_layers = encoder_layers
422
+ self.decoder_layers = decoder_layers
423
+ self.num_buckets = num_buckets
424
+
425
+ # layers
426
+ self.token_embedding = nn.Embedding(vocab_size, dim)
427
+ self.encoder = T5Encoder(
428
+ self.token_embedding,
429
+ dim,
430
+ dim_attn,
431
+ dim_ffn,
432
+ num_heads,
433
+ encoder_layers,
434
+ num_buckets,
435
+ shared_pos,
436
+ dropout,
437
+ )
438
+ self.decoder = T5Decoder(
439
+ self.token_embedding,
440
+ dim,
441
+ dim_attn,
442
+ dim_ffn,
443
+ num_heads,
444
+ decoder_layers,
445
+ num_buckets,
446
+ shared_pos,
447
+ dropout,
448
+ )
449
+ self.head = nn.Linear(dim, vocab_size, bias=False)
450
+
451
+ # initialize weights
452
+ self.apply(init_weights)
453
+
454
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
455
+ x = self.encoder(encoder_ids, encoder_mask)
456
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
457
+ x = self.head(x)
458
+ return x
459
+
460
+
461
+ def _t5(
462
+ name,
463
+ encoder_only=False,
464
+ decoder_only=False,
465
+ return_tokenizer=False,
466
+ tokenizer_kwargs={},
467
+ dtype=torch.float32,
468
+ device="cpu",
469
+ **kwargs,
470
+ ):
471
+ # sanity check
472
+ assert not (encoder_only and decoder_only)
473
+
474
+ # params
475
+ if encoder_only:
476
+ model_cls = T5Encoder
477
+ kwargs["vocab"] = kwargs.pop("vocab_size")
478
+ kwargs["num_layers"] = kwargs.pop("encoder_layers")
479
+ _ = kwargs.pop("decoder_layers")
480
+ elif decoder_only:
481
+ model_cls = T5Decoder
482
+ kwargs["vocab"] = kwargs.pop("vocab_size")
483
+ kwargs["num_layers"] = kwargs.pop("decoder_layers")
484
+ _ = kwargs.pop("encoder_layers")
485
+ else:
486
+ model_cls = T5Model
487
+
488
+ # init model
489
+ with torch.device(device):
490
+ model = model_cls(**kwargs)
491
+
492
+ # set device
493
+ model = model.to(dtype=dtype, device=device)
494
+
495
+ # init tokenizer
496
+ if return_tokenizer:
497
+ from .tokenizers import HuggingfaceTokenizer
498
+
499
+ tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
500
+ return model, tokenizer
501
+ else:
502
+ return model
503
+
504
+
505
+ def umt5_xxl(**kwargs):
506
+ cfg = dict(
507
+ vocab_size=256384,
508
+ dim=4096,
509
+ dim_attn=4096,
510
+ dim_ffn=10240,
511
+ num_heads=64,
512
+ encoder_layers=24,
513
+ decoder_layers=24,
514
+ num_buckets=32,
515
+ shared_pos=False,
516
+ dropout=0.1,
517
+ )
518
+ cfg.update(**kwargs)
519
+ return _t5("umt5-xxl", **cfg)
520
+
521
+
522
+ class T5EncoderModel:
523
+ def __init__(
524
+ self,
525
+ text_len,
526
+ dtype=torch.bfloat16,
527
+ device=torch.cuda.current_device(),
528
+ checkpoint_path=None,
529
+ tokenizer_path=None,
530
+ shard_fn=None,
531
+ ):
532
+ self.text_len = text_len
533
+ self.dtype = dtype
534
+ self.device = device
535
+ self.checkpoint_path = checkpoint_path
536
+ self.tokenizer_path = tokenizer_path
537
+
538
+ # init model
539
+ model = (
540
+ umt5_xxl(
541
+ encoder_only=True, return_tokenizer=False, dtype=dtype, device=device
542
+ )
543
+ .eval()
544
+ .requires_grad_(False)
545
+ )
546
+ logging.info(f"loading {checkpoint_path}")
547
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
548
+ self.model = model
549
+ if shard_fn is not None:
550
+ self.model = shard_fn(self.model, sync_module_states=False)
551
+ else:
552
+ self.model.to(self.device)
553
+ # init tokenizer
554
+ self.tokenizer = HuggingfaceTokenizer(
555
+ name=tokenizer_path, seq_len=text_len, clean="whitespace"
556
+ )
557
+
558
+ def __call__(self, texts, device):
559
+ ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
560
+ ids = ids.to(device)
561
+ mask = mask.to(device)
562
+ seq_lens = mask.gt(0).sum(dim=1).long()
563
+ context = self.model(ids, mask)
564
+ return [u[:v] for u, v in zip(context, seq_lens)]
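One possible way to wire up the encoder wrapper above is sketched below; the checkpoint and tokenizer paths are placeholders rather than what this repository necessarily ships, and a CUDA device plus UMT5-XXL encoder weights are assumed.

```python
# Hypothetical text-encoding call (paths are placeholders).
import torch
from ldf_models.tools.t5 import T5EncoderModel

text_encoder = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    checkpoint_path="path/to/umt5-xxl-enc-bf16.pth",   # placeholder
    tokenizer_path="google/umt5-xxl",                  # placeholder
)
contexts = text_encoder(["a person walks forward and waves"], torch.device("cuda"))
print(contexts[0].shape)  # (num_real_tokens, 4096); one tensor per prompt
```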
ldf_models/tools/tokenizers.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ["HuggingfaceTokenizer"]
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r"\s+", " ", text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace("_", " ")
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans("", "", string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string)
30
+ )
31
+ else:
32
+ text = text.translate(str.maketrans("", "", string.punctuation))
33
+ text = text.lower()
34
+ text = re.sub(r"\s+", " ", text)
35
+ return text.strip()
36
+
37
+
38
+ class HuggingfaceTokenizer:
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, "whitespace", "lower", "canonicalize")
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop("return_mask", False)
51
+
52
+ # arguments
53
+ _kwargs = {"return_tensors": "pt"}
54
+ if self.seq_len is not None:
55
+ _kwargs.update(
56
+ {
57
+ "padding": "max_length",
58
+ "truncation": True,
59
+ "max_length": self.seq_len,
60
+ }
61
+ )
62
+ _kwargs.update(**kwargs)
63
+
64
+ # tokenization
65
+ if isinstance(sequence, str):
66
+ sequence = [sequence]
67
+ if self.clean:
68
+ sequence = [self._clean(u) for u in sequence]
69
+ ids = self.tokenizer(sequence, **_kwargs)
70
+
71
+ # output
72
+ if return_mask:
73
+ return ids.input_ids, ids.attention_mask
74
+ else:
75
+ return ids.input_ids
76
+
77
+ def _clean(self, text):
78
+ if self.clean == "whitespace":
79
+ text = whitespace_clean(basic_clean(text))
80
+ elif self.clean == "lower":
81
+ text = whitespace_clean(basic_clean(text)).lower()
82
+ elif self.clean == "canonicalize":
83
+ text = canonicalize(basic_clean(text))
84
+ return text
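The wrapper above simply forwards to a Hugging Face `AutoTokenizer` with optional text cleaning and fixed-length padding. A quick check, assuming `google/umt5-xxl` as an example tokenizer name:

```python
# Quick tokenizer check ("google/umt5-xxl" is only an example name).
from ldf_models.tools.tokenizers import HuggingfaceTokenizer

tok = HuggingfaceTokenizer(name="google/umt5-xxl", seq_len=512, clean="whitespace")
ids, mask = tok(["A person  jumps &amp; spins."], return_mask=True, add_special_tokens=True)
print(ids.shape, mask.shape)  # both (1, 512): padded/truncated to seq_len
print(int(mask.sum()))        # number of non-padding tokens
```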
ldf_models/tools/wan_model.py ADDED
@@ -0,0 +1,592 @@
1
+ # This module uses modified code from Alibaba Wan Team
2
+ # Original source: https://github.com/Wan-Video/Wan2.2
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+ # Modified to support stream mode for cross-attention.
5
+ # Added causal attention for self-attention (1d case)
6
+ # Added context length correction.
7
+
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+
15
+ from .attention import flash_attention
16
+
17
+
18
+ def sinusoidal_embedding_1d(dim, position):
19
+ # preprocess
20
+ assert dim % 2 == 0
21
+ half = dim // 2
22
+ position = position.type(torch.float64)
23
+
24
+ # calculation
25
+ sinusoid = torch.outer(
26
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half))
27
+ )
28
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
29
+ return x
30
+
31
+
32
+ @torch.amp.autocast("cuda", enabled=False)
33
+ def rope_params(max_seq_len, dim, theta=10000):
34
+ assert dim % 2 == 0
35
+ freqs = torch.outer(
36
+ torch.arange(max_seq_len),
37
+ 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
38
+ )
39
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
40
+ return freqs
41
+
42
+
43
+ @torch.amp.autocast("cuda", enabled=False)
44
+ def rope_apply(x, grid_sizes, freqs):
45
+ n, c = x.size(2), x.size(3) // 2
46
+
47
+ # split freqs
48
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
49
+
50
+ # loop over samples
51
+ output = []
52
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
53
+ seq_len = f * h * w
54
+
55
+ # precompute multipliers
56
+ x_i = torch.view_as_complex(
57
+ x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
58
+ )
59
+ freqs_i = torch.cat(
60
+ [
61
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
62
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
63
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
64
+ ],
65
+ dim=-1,
66
+ ).reshape(seq_len, 1, -1)
67
+
68
+ # apply rotary embedding
69
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
70
+ x_i = torch.cat([x_i, x[i, seq_len:]])
71
+
72
+ # append to collection
73
+ output.append(x_i)
74
+ return torch.stack(output).float()
75
+
76
+
77
+ class WanRMSNorm(nn.Module):
78
+ def __init__(self, dim, eps=1e-5):
79
+ super().__init__()
80
+ self.dim = dim
81
+ self.eps = eps
82
+ self.weight = nn.Parameter(torch.ones(dim))
83
+
84
+ def forward(self, x):
85
+ r"""
86
+ Args:
87
+ x(Tensor): Shape [B, L, C]
88
+ """
89
+ return self._norm(x.float()).type_as(x) * self.weight
90
+
91
+ def _norm(self, x):
92
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
93
+
94
+
95
+ class WanLayerNorm(nn.LayerNorm):
96
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
97
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
98
+
99
+ def forward(self, x):
100
+ r"""
101
+ Args:
102
+ x(Tensor): Shape [B, L, C]
103
+ """
104
+ return super().forward(x.float()).type_as(x)
105
+
106
+
107
+ class WanSelfAttention(nn.Module):
108
+ def __init__(
109
+ self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, causal=False
110
+ ):
111
+ assert dim % num_heads == 0
112
+ super().__init__()
113
+ self.dim = dim
114
+ self.num_heads = num_heads
115
+ self.head_dim = dim // num_heads
116
+ self.window_size = window_size
117
+ self.qk_norm = qk_norm
118
+ self.eps = eps
119
+ self.causal = causal
120
+ # layers
121
+ self.q = nn.Linear(dim, dim)
122
+ self.k = nn.Linear(dim, dim)
123
+ self.v = nn.Linear(dim, dim)
124
+ self.o = nn.Linear(dim, dim)
125
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
126
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
127
+
128
+ def forward(self, x, seq_lens, grid_sizes, freqs):
129
+ r"""
130
+ Args:
131
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
132
+ seq_lens(Tensor): Shape [B]
133
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
134
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
135
+ """
136
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
137
+
138
+ # query, key, value function
139
+ def qkv_fn(x):
140
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
141
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
142
+ v = self.v(x).view(b, s, n, d)
143
+ return q, k, v
144
+
145
+ q, k, v = qkv_fn(x)
146
+
147
+ x = flash_attention(
148
+ q=rope_apply(q, grid_sizes, freqs),
149
+ k=rope_apply(k, grid_sizes, freqs),
150
+ v=v,
151
+ k_lens=seq_lens,
152
+ window_size=self.window_size,
153
+ causal=self.causal,
154
+ )
155
+
156
+ # output
157
+ x = x.flatten(2)
158
+ x = self.o(x)
159
+ return x
160
+
161
+
162
+ class WanCrossAttention(WanSelfAttention):
163
+ def forward(self, x, context, context_lens):
164
+ r"""
165
+ Args non-stream mode:
166
+ x(Tensor): Shape [B, L1, C]
167
+ context(Tensor): Shape [B, L2, C]
168
+ context_lens(Tensor): Shape [B]
169
+ Args stream mode:
170
+ x(Tensor): Shape [B, L1, C]
171
+ context(Tensor): Shape [BxL1, L2, C]
172
+ context_lens(Tensor): Shape [BxL1]
173
+ """
174
+ out_sizes = x.size()
175
+ b, n, d = context.size(0), self.num_heads, self.head_dim
176
+
177
+ # compute query, key, value
178
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
179
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
180
+ v = self.v(context).view(b, -1, n, d)
181
+
182
+ # compute attention
183
+ x = flash_attention(q, k, v, k_lens=context_lens)
184
+
185
+ # output
186
+ x = x.flatten(2).view(*out_sizes)
187
+ x = self.o(x)
188
+ return x
189
+
190
+
191
+ class WanAttentionBlock(nn.Module):
192
+ def __init__(
193
+ self,
194
+ dim,
195
+ ffn_dim,
196
+ num_heads,
197
+ window_size=(-1, -1),
198
+ qk_norm=True,
199
+ cross_attn_norm=False,
200
+ eps=1e-6,
201
+ causal=False,
202
+ ):
203
+ super().__init__()
204
+ self.dim = dim
205
+ self.ffn_dim = ffn_dim
206
+ self.num_heads = num_heads
207
+ self.window_size = window_size
208
+ self.qk_norm = qk_norm
209
+ self.cross_attn_norm = cross_attn_norm
210
+ self.eps = eps
211
+ self.causal = causal
212
+ # layers
213
+ self.norm1 = WanLayerNorm(dim, eps)
214
+ self.self_attn = WanSelfAttention(
215
+ dim, num_heads, window_size, qk_norm, eps, causal
216
+ )
217
+ self.norm3 = (
218
+ WanLayerNorm(dim, eps, elementwise_affine=True)
219
+ if cross_attn_norm
220
+ else nn.Identity()
221
+ )
222
+
223
+ self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)
224
+ self.norm2 = WanLayerNorm(dim, eps)
225
+ self.ffn = nn.Sequential(
226
+ nn.Linear(dim, ffn_dim),
227
+ nn.GELU(approximate="tanh"),
228
+ nn.Linear(ffn_dim, dim),
229
+ )
230
+
231
+ # modulation
232
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
233
+
234
+ def forward(
235
+ self,
236
+ x,
237
+ e,
238
+ seq_lens,
239
+ grid_sizes,
240
+ freqs,
241
+ context,
242
+ context_lens,
243
+ ):
244
+ r"""
245
+ Args:
246
+ x(Tensor): Shape [B, L, C]
247
+ e(Tensor): Shape [B, L1, 6, C]
248
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
249
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
250
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
251
+ """
252
+ assert e.dtype == torch.float32
253
+ with torch.amp.autocast("cuda", dtype=torch.float32):
254
+ e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
255
+ assert e[0].dtype == torch.float32
256
+
257
+ # self-attention
258
+ y = self.self_attn(
259
+ self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
260
+ seq_lens,
261
+ grid_sizes,
262
+ freqs,
263
+ )
264
+ with torch.amp.autocast("cuda", dtype=torch.float32):
265
+ x = x + y * e[2].squeeze(2)
266
+
267
+ # cross-attention & ffn function
268
+ def cross_attn_ffn(x, context, context_lens, e):
269
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
270
+ y = self.ffn(
271
+ self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)
272
+ )
273
+ with torch.amp.autocast("cuda", dtype=torch.float32):
274
+ x = x + y * e[5].squeeze(2)
275
+ return x
276
+
277
+ x = cross_attn_ffn(x, context, context_lens, e)
278
+ return x
279
+
280
+
281
+ class Head(nn.Module):
282
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
283
+ super().__init__()
284
+ self.dim = dim
285
+ self.out_dim = out_dim
286
+ self.patch_size = patch_size
287
+ self.eps = eps
288
+
289
+ # layers
290
+ out_dim = math.prod(patch_size) * out_dim
291
+ self.norm = WanLayerNorm(dim, eps)
292
+ self.head = nn.Linear(dim, out_dim)
293
+
294
+ # modulation
295
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
296
+
297
+ def forward(self, x, e):
298
+ r"""
299
+ Args:
300
+ x(Tensor): Shape [B, L1, C]
301
+ e(Tensor): Shape [B, L1, C]
302
+ """
303
+ assert e.dtype == torch.float32
304
+ with torch.amp.autocast("cuda", dtype=torch.float32):
305
+ e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
306
+ x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2))
307
+ return x
308
+
309
+
310
+ class WanModel(ModelMixin, ConfigMixin):
311
+ r"""
312
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
313
+ """
314
+
315
+ ignore_for_config = [
316
+ "patch_size",
317
+ "cross_attn_norm",
318
+ "qk_norm",
319
+ "text_dim",
320
+ "window_size",
321
+ ]
322
+ _no_split_modules = ["WanAttentionBlock"]
323
+
324
+ @register_to_config
325
+ def __init__(
326
+ self,
327
+ model_type="t2v",
328
+ patch_size=(1, 2, 2),
329
+ text_len=512,
330
+ in_dim=16,
331
+ dim=2048,
332
+ ffn_dim=8192,
333
+ freq_dim=256,
334
+ text_dim=4096,
335
+ out_dim=16,
336
+ num_heads=16,
337
+ num_layers=32,
338
+ window_size=(-1, -1),
339
+ qk_norm=True,
340
+ cross_attn_norm=True,
341
+ eps=1e-6,
342
+ causal=False,
343
+ ):
344
+ r"""
345
+ Initialize the diffusion model backbone.
346
+
347
+ Args:
348
+ model_type (`str`, *optional*, defaults to 't2v'):
349
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
350
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
351
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
352
+ text_len (`int`, *optional*, defaults to 512):
353
+ Fixed length for text embeddings
354
+ in_dim (`int`, *optional*, defaults to 16):
355
+ Input video channels (C_in)
356
+ dim (`int`, *optional*, defaults to 2048):
357
+ Hidden dimension of the transformer
358
+ ffn_dim (`int`, *optional*, defaults to 8192):
359
+ Intermediate dimension in feed-forward network
360
+ freq_dim (`int`, *optional*, defaults to 256):
361
+ Dimension for sinusoidal time embeddings
362
+ text_dim (`int`, *optional*, defaults to 4096):
363
+ Input dimension for text embeddings
364
+ out_dim (`int`, *optional*, defaults to 16):
365
+ Output video channels (C_out)
366
+ num_heads (`int`, *optional*, defaults to 16):
367
+ Number of attention heads
368
+ num_layers (`int`, *optional*, defaults to 32):
369
+ Number of transformer blocks
370
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
371
+ Window size for local attention (-1 indicates global attention)
372
+ qk_norm (`bool`, *optional*, defaults to True):
373
+ Enable query/key normalization
374
+ cross_attn_norm (`bool`, *optional*, defaults to True):
375
+ Enable cross-attention normalization
376
+ eps (`float`, *optional*, defaults to 1e-6):
377
+ Epsilon value for normalization layers
378
+ """
379
+
380
+ super().__init__()
381
+
382
+ assert model_type in ["t2v", "i2v", "ti2v", "s2v"]
383
+ self.model_type = model_type
384
+
385
+ self.patch_size = patch_size
386
+ self.text_len = text_len
387
+ self.in_dim = in_dim
388
+ self.dim = dim
389
+ self.ffn_dim = ffn_dim
390
+ self.freq_dim = freq_dim
391
+ self.text_dim = text_dim
392
+ self.out_dim = out_dim
393
+ self.num_heads = num_heads
394
+ self.num_layers = num_layers
395
+ self.window_size = window_size
396
+ self.qk_norm = qk_norm
397
+ self.cross_attn_norm = cross_attn_norm
398
+ self.eps = eps
399
+ self.causal = causal
400
+ # embeddings
401
+ self.patch_embedding = nn.Conv3d(
402
+ in_dim, dim, kernel_size=patch_size, stride=patch_size
403
+ )
404
+ self.text_embedding = nn.Sequential(
405
+ nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)
406
+ )
407
+
408
+ self.time_embedding = nn.Sequential(
409
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
410
+ )
411
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
412
+
413
+ # blocks
414
+ self.blocks = nn.ModuleList(
415
+ [
416
+ WanAttentionBlock(
417
+ dim,
418
+ ffn_dim,
419
+ num_heads,
420
+ window_size,
421
+ qk_norm,
422
+ cross_attn_norm,
423
+ eps,
424
+ causal,
425
+ )
426
+ for _ in range(num_layers)
427
+ ]
428
+ )
429
+
430
+ # head
431
+ self.head = Head(dim, out_dim, patch_size, eps)
432
+
433
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
434
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
435
+ d = dim // num_heads
436
+ self.freqs = torch.cat(
437
+ [
438
+ rope_params(1024, d - 4 * (d // 6)),
439
+ rope_params(1024, 2 * (d // 6)),
440
+ rope_params(1024, 2 * (d // 6)),
441
+ ],
442
+ dim=1,
443
+ )
444
+
445
+ # initialize weights
446
+ self.init_weights()
447
+
448
+ def forward(
449
+ self,
450
+ x,
451
+ t,
452
+ context,
453
+ seq_len,
454
+ y=None,
455
+ ):
456
+ r"""
457
+ Forward pass through the diffusion model
458
+
459
+ Args:
460
+ x (List[Tensor]):
461
+ List of input video tensors, each with shape [C_in, F, H, W]
462
+ t (Tensor):
463
+ Diffusion timesteps tensor of shape [B]
464
+ context (List[Tensor]):
465
+ List of text embeddings each with shape [L, C]
466
+ seq_len (`int`):
467
+ Maximum sequence length for positional encoding
468
+ y (List[Tensor], *optional*):
469
+ Conditional video inputs for image-to-video mode, same shape as x
470
+
471
+ Returns:
472
+ List[Tensor]:
473
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
474
+ """
475
+ if self.model_type == "i2v":
476
+ assert y is not None
477
+ # params
478
+ device = self.patch_embedding.weight.device
479
+ if self.freqs.device != device:
480
+ self.freqs = self.freqs.to(device)
481
+
482
+ if y is not None:
483
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
484
+
485
+ # embeddings
486
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
487
+ grid_sizes = torch.stack(
488
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]
489
+ )
490
+ x = [u.flatten(2).transpose(1, 2) for u in x]
491
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
492
+ assert seq_lens.max() <= seq_len
493
+ x = torch.cat(
494
+ [
495
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
496
+ for u in x
497
+ ]
498
+ )
499
+
500
+ # time embeddings
501
+ if t.dim() == 1: # bs
502
+ t = t.expand(t.size(0), seq_len)
503
+ with torch.amp.autocast("cuda", dtype=torch.float32):
504
+ bt = t.size(0)
505
+ t = t.flatten()
506
+ e = self.time_embedding(
507
+ sinusoidal_embedding_1d(self.freq_dim, t)
508
+ .unflatten(0, (bt, seq_len))
509
+ .float()
510
+ )
511
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
512
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
513
+
514
+ # context
515
+ context_lens = torch.tensor([u.size(0) for u in context], dtype=torch.long)
516
+ context = self.text_embedding(
517
+ torch.stack(
518
+ [
519
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
520
+ for u in context
521
+ ]
522
+ )
523
+ )
524
+
525
+ # arguments
526
+ kwargs = dict(
527
+ e=e0,
528
+ seq_lens=seq_lens,
529
+ grid_sizes=grid_sizes,
530
+ freqs=self.freqs,
531
+ context=context,
532
+ context_lens=context_lens,
533
+ )
534
+
535
+ for block in self.blocks:
536
+ x = block(x, **kwargs)
537
+
538
+ # head
539
+ x = self.head(x, e)
540
+
541
+ # unpatchify
542
+ x = self.unpatchify(x, grid_sizes)
543
+ return [u.float() for u in x]
544
+
545
+ def unpatchify(self, x, grid_sizes):
546
+ r"""
547
+ Reconstruct video tensors from patch embeddings.
548
+
549
+ Args:
550
+ x (List[Tensor]):
551
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
552
+ grid_sizes (Tensor):
553
+ Original spatial-temporal grid dimensions before patching,
554
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
555
+
556
+ Returns:
557
+ List[Tensor]:
558
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
559
+ """
560
+
561
+ c = self.out_dim
562
+ out = []
563
+ for u, v in zip(x, grid_sizes.tolist()):
564
+ u = u[: math.prod(v)].view(*v, *self.patch_size, c)
565
+ u = torch.einsum("fhwpqrc->cfphqwr", u)
566
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
567
+ out.append(u)
568
+ return out
569
+
570
+ def init_weights(self):
571
+ r"""
572
+ Initialize model parameters using Xavier initialization.
573
+ """
574
+
575
+ # basic init
576
+ for m in self.modules():
577
+ if isinstance(m, nn.Linear):
578
+ nn.init.xavier_uniform_(m.weight)
579
+ if m.bias is not None:
580
+ nn.init.zeros_(m.bias)
581
+
582
+ # init embeddings
583
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
584
+ for m in self.text_embedding.modules():
585
+ if isinstance(m, nn.Linear):
586
+ nn.init.normal_(m.weight, std=0.02)
587
+ for m in self.time_embedding.modules():
588
+ if isinstance(m, nn.Linear):
589
+ nn.init.normal_(m.weight, std=0.02)
590
+
591
+ # init output layer
592
+ nn.init.zeros_(self.head.head.weight)
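For reference, here is an illustrative tiny configuration of this backbone for 1-D motion latents, where the spatial grid is 1x1 as in the streaming code earlier in this commit. The hyper-parameters are made up for the example (the real ones ship with the checkpoint), and running `forward` additionally requires a CUDA device with flash-attn installed because `WanSelfAttention` calls `flash_attention` directly.

```python
# Illustrative tiny instantiation (hyper-parameters are made up for the example).
from ldf_models.tools.wan_model import WanModel

model = WanModel(
    model_type="t2v",
    patch_size=(1, 1, 1),   # no patching: one token per motion frame
    in_dim=64, out_dim=64,  # latent channels of the 1-D motion VAE
    dim=256, ffn_dim=1024,
    num_heads=8, num_layers=4,
    text_dim=4096,          # width of the UMT5 text features
    causal=True,            # causal self-attention along time
)
# forward() expects a list of [C_in, F, 1, 1] tensors, per-frame timesteps,
# and a list of [L, text_dim] context tensors, and returns denoised latents.
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```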
ldf_models/tools/wan_vae_1d.py ADDED
@@ -0,0 +1,762 @@
1
+ # This module uses modified code from Alibaba Wan Team
2
+ # Original source: https://github.com/Wan-Video/Wan2.2
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+ # Modified to support 1d features with (B, C, T)
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ CACHE_T = 2
11
+
12
+
13
+ class CausalConv1d(nn.Conv1d):
14
+ """
15
+ Causal 1d convolution.
16
+ """
17
+
18
+ def __init__(self, *args, **kwargs):
19
+ super().__init__(*args, **kwargs)
20
+ self._padding = (
21
+ 2 * self.padding[0],
22
+ 0,
23
+ )
24
+ self.padding = (0,)
25
+
26
+ def forward(self, x, cache_x=None):
27
+ padding = list(self._padding)
28
+ if cache_x is not None and self._padding[0] > 0:
29
+ cache_x = cache_x.to(x.device)
30
+ x = torch.cat([cache_x, x], dim=2)
31
+ padding[0] -= cache_x.shape[2]
32
+ x = F.pad(x, padding)
33
+
34
+ return super().forward(x)
35
+
36
+
37
+ class RMS_norm(nn.Module):
38
+ def __init__(self, dim, channel_first=True, bias=False):
39
+ super().__init__()
40
+ broadcastable_dims = (1,)
41
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
42
+
43
+ self.channel_first = channel_first
44
+ self.scale = dim**0.5
45
+ self.gamma = nn.Parameter(torch.ones(shape))
46
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
47
+
48
+ def forward(self, x):
49
+ return (
50
+ F.normalize(x, dim=(1 if self.channel_first else -1))
51
+ * self.scale
52
+ * self.gamma
53
+ + self.bias
54
+ )
55
+
56
+
57
+ class Upsample(nn.Upsample):
58
+ def forward(self, x):
59
+ """
60
+ Fix bfloat16 support for nearest neighbor interpolation.
61
+ """
62
+ return super().forward(x.float()).type_as(x)
63
+
64
+
65
+ class Resample(nn.Module):
66
+ def __init__(self, dim, mode):
67
+ assert mode in (
68
+ "upsample1d",
69
+ "downsample1d",
70
+ )
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.mode = mode
74
+
75
+ # layers
76
+ if mode == "upsample1d":
77
+ self.time_conv = CausalConv1d(dim, dim * 2, (3,), padding=(1,))
78
+ elif mode == "downsample1d":
79
+ self.time_conv = CausalConv1d(dim, dim, (3,), stride=(2,), padding=(0,))
80
+
81
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
82
+ b, c, t = x.size()
83
+ if self.mode == "upsample1d":
84
+ if feat_cache is not None:
85
+ idx = feat_idx[0]
86
+ if feat_cache[idx] is None:
87
+ feat_cache[idx] = "Rep"
88
+ feat_idx[0] += 1
89
+ else:
90
+ cache_x = x[:, :, -CACHE_T:].clone()
91
+ if (
92
+ cache_x.shape[2] < 2
93
+ and feat_cache[idx] is not None
94
+ and feat_cache[idx] != "Rep"
95
+ ):
96
+ # cache last frame of last two chunk
97
+ cache_x = torch.cat(
98
+ [
99
+ feat_cache[idx][:, :, -1]
100
+ .unsqueeze(2)
101
+ .to(cache_x.device),
102
+ cache_x,
103
+ ],
104
+ dim=2,
105
+ )
106
+ if (
107
+ cache_x.shape[2] < 2
108
+ and feat_cache[idx] is not None
109
+ and feat_cache[idx] == "Rep"
110
+ ):
111
+ cache_x = torch.cat(
112
+ [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
113
+ dim=2,
114
+ )
115
+ if feat_cache[idx] == "Rep":
116
+ x = self.time_conv(x)
117
+ else:
118
+ x = self.time_conv(x, feat_cache[idx])
119
+ feat_cache[idx] = cache_x
120
+ feat_idx[0] += 1
121
+ x = x.reshape(b, 2, c, t)
122
+ x = torch.stack((x[:, 0, :, :], x[:, 1, :, :]), 3)
123
+ x = x.reshape(b, c, t * 2)
124
+
125
+ if self.mode == "downsample1d":
126
+ if feat_cache is not None:
127
+ idx = feat_idx[0]
128
+ if feat_cache[idx] is None:
129
+ feat_cache[idx] = x.clone()
130
+ feat_idx[0] += 1
131
+ else:
132
+ cache_x = x[:, :, -1:].clone()
133
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:], x], 2))
134
+ feat_cache[idx] = cache_x
135
+ feat_idx[0] += 1
136
+ return x
137
+
138
+
139
+ class ResidualBlock(nn.Module):
140
+ def __init__(self, in_dim, out_dim, dropout=0.0):
141
+ super().__init__()
142
+ self.in_dim = in_dim
143
+ self.out_dim = out_dim
144
+
145
+ # layers
146
+ self.residual = nn.Sequential(
147
+ RMS_norm(in_dim),
148
+ nn.SiLU(),
149
+ CausalConv1d(in_dim, out_dim, 3, padding=1),
150
+ RMS_norm(out_dim),
151
+ nn.SiLU(),
152
+ nn.Dropout(dropout),
153
+ CausalConv1d(out_dim, out_dim, 3, padding=1),
154
+ )
155
+ self.shortcut = (
156
+ CausalConv1d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
157
+ )
158
+
159
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
160
+ h = self.shortcut(x)
161
+ for layer in self.residual:
162
+ if isinstance(layer, CausalConv1d) and feat_cache is not None:
163
+ idx = feat_idx[0]
164
+ cache_x = x[:, :, -CACHE_T:].clone()
165
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
166
+ # cache last frame of last two chunk
167
+ cache_x = torch.cat(
168
+ [
169
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
170
+ cache_x,
171
+ ],
172
+ dim=2,
173
+ )
174
+ x = layer(x, feat_cache[idx])
175
+ feat_cache[idx] = cache_x
176
+ feat_idx[0] += 1
177
+ else:
178
+ x = layer(x)
179
+ return x + h
180
+
181
+
182
+ class AvgDown1D(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_channels,
186
+ out_channels,
187
+ factor_t,
188
+ ):
189
+ super().__init__()
190
+ self.in_channels = in_channels
191
+ self.out_channels = out_channels
192
+ self.factor_t = factor_t
193
+ self.factor = self.factor_t
194
+
195
+ assert in_channels * self.factor % out_channels == 0
196
+ self.group_size = in_channels * self.factor // out_channels
197
+
198
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
199
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
200
+ pad = (pad_t, 0)
201
+ x = F.pad(x, pad)
202
+ B, C, T = x.shape
203
+ x = x.view(
204
+ B,
205
+ C,
206
+ T // self.factor_t,
207
+ self.factor_t,
208
+ )
209
+ x = x.permute(0, 1, 3, 2).contiguous()
210
+ x = x.view(
211
+ B,
212
+ C * self.factor,
213
+ T // self.factor_t,
214
+ )
215
+ x = x.view(
216
+ B,
217
+ self.out_channels,
218
+ self.group_size,
219
+ T // self.factor_t,
220
+ )
221
+ x = x.mean(dim=2)
222
+ return x
223
+
224
+
225
+ class DupUp1D(nn.Module):
226
+ def __init__(
227
+ self,
228
+ in_channels: int,
229
+ out_channels: int,
230
+ factor_t,
231
+ ):
232
+ super().__init__()
233
+ self.in_channels = in_channels
234
+ self.out_channels = out_channels
235
+
236
+ self.factor_t = factor_t
237
+ self.factor = self.factor_t
238
+
239
+ assert out_channels * self.factor % in_channels == 0
240
+ self.repeats = out_channels * self.factor // in_channels
241
+
242
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
243
+ x = x.repeat_interleave(self.repeats, dim=1)
244
+ x = x.view(
245
+ x.size(0),
246
+ self.out_channels,
247
+ self.factor_t,
248
+ x.size(2),
249
+ )
250
+ x = x.permute(0, 1, 3, 2).contiguous()
251
+ x = x.view(
252
+ x.size(0),
253
+ self.out_channels,
254
+ x.size(2) * self.factor_t,
255
+ )
256
+ if first_chunk:
257
+ x = x[
258
+ :,
259
+ :,
260
+ self.factor_t - 1 :,
261
+ ]
262
+ return x
263
+
264
+
265
+ class Down_ResidualBlock(nn.Module):
266
+ def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False):
267
+ super().__init__()
268
+
269
+ # Shortcut path with downsample
270
+ if temperal_downsample:
271
+ self.avg_shortcut = AvgDown1D(
272
+ in_dim,
273
+ out_dim,
274
+ factor_t=2,
275
+ )
276
+ else:
277
+ self.avg_shortcut = None
278
+
279
+ # Main path with residual blocks and downsample
280
+ downsamples = []
281
+ for _ in range(mult):
282
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
283
+ in_dim = out_dim
284
+
285
+ # Add the final downsample block
286
+ if temperal_downsample:
287
+ downsamples.append(Resample(out_dim, mode="downsample1d"))
288
+
289
+ self.downsamples = nn.Sequential(*downsamples)
290
+
291
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
292
+ x_copy = x.clone()
293
+ for module in self.downsamples:
294
+ x = module(x, feat_cache, feat_idx)
295
+ if self.avg_shortcut is None:
296
+ return x
297
+ else:
298
+ return x + self.avg_shortcut(x_copy)
299
+
300
+
301
+ class Up_ResidualBlock(nn.Module):
302
+ def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False):
303
+ super().__init__()
304
+ # Shortcut path with upsample
305
+ if temperal_upsample:
306
+ self.avg_shortcut = DupUp1D(
307
+ in_dim,
308
+ out_dim,
309
+ factor_t=2,
310
+ )
311
+ else:
312
+ self.avg_shortcut = None
313
+
314
+ # Main path with residual blocks and upsample
315
+ upsamples = []
316
+ for _ in range(mult):
317
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
318
+ in_dim = out_dim
319
+
320
+ # Add the final upsample block
321
+ if temperal_upsample:
322
+ upsamples.append(Resample(out_dim, mode="upsample1d"))
323
+
324
+ self.upsamples = nn.Sequential(*upsamples)
325
+
326
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
327
+ x_main = x.clone()
328
+ for module in self.upsamples:
329
+ x_main = module(x_main, feat_cache, feat_idx)
330
+ if self.avg_shortcut is not None:
331
+ x_shortcut = self.avg_shortcut(x, first_chunk)
332
+ return x_main + x_shortcut
333
+ else:
334
+ return x_main
335
+
336
+
337
+ class Encoder1d(nn.Module):
338
+ def __init__(
339
+ self,
340
+ input_dim,
341
+ dim=128,
342
+ z_dim=4,
343
+ dim_mult=[1, 2, 4, 4],
344
+ num_res_blocks=2,
345
+ temperal_downsample=[True, True, False],
346
+ dropout=0.0,
347
+ ):
348
+ super().__init__()
349
+ self.dim = dim
350
+ self.z_dim = z_dim
351
+ self.dim_mult = dim_mult
352
+ self.num_res_blocks = num_res_blocks
353
+ self.temperal_downsample = temperal_downsample
354
+
355
+ # dimensions
356
+ dims = [dim * u for u in [1] + dim_mult]
357
+ scale = 1.0
358
+
359
+ # init block
360
+ self.conv1 = CausalConv1d(input_dim, dims[0], 3, padding=1)
361
+
362
+ # downsample blocks
363
+ downsamples = []
364
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
365
+ t_down_flag = (
366
+ temperal_downsample[i] if i < len(temperal_downsample) else False
367
+ )
368
+ downsamples.append(
369
+ Down_ResidualBlock(
370
+ in_dim=in_dim,
371
+ out_dim=out_dim,
372
+ dropout=dropout,
373
+ mult=num_res_blocks,
374
+ temperal_downsample=t_down_flag,
375
+ )
376
+ )
377
+ scale /= 2.0
378
+ self.downsamples = nn.Sequential(*downsamples)
379
+
380
+ # middle blocks
381
+ self.middle = nn.Sequential(
382
+ ResidualBlock(out_dim, out_dim, dropout),
383
+ RMS_norm(out_dim),
384
+ CausalConv1d(out_dim, out_dim, 1),
385
+ ResidualBlock(out_dim, out_dim, dropout),
386
+ )
387
+
388
+ # output blocks
389
+ self.head = nn.Sequential(
390
+ RMS_norm(out_dim),
391
+ nn.SiLU(),
392
+ CausalConv1d(out_dim, z_dim, 3, padding=1),
393
+ )
394
+
395
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
396
+ if feat_cache is not None:
397
+ idx = feat_idx[0]
398
+ cache_x = x[:, :, -CACHE_T:].clone()
399
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
400
+ cache_x = torch.cat(
401
+ [
402
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
403
+ cache_x,
404
+ ],
405
+ dim=2,
406
+ )
407
+ x = self.conv1(x, feat_cache[idx])
408
+ feat_cache[idx] = cache_x
409
+ feat_idx[0] += 1
410
+ else:
411
+ x = self.conv1(x)
412
+
413
+ ## downsamples
414
+ for layer in self.downsamples:
415
+ if feat_cache is not None:
416
+ x = layer(x, feat_cache, feat_idx)
417
+ else:
418
+ x = layer(x)
419
+
420
+ ## middle
421
+ for layer in self.middle:
422
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
423
+ x = layer(x, feat_cache, feat_idx)
424
+ else:
425
+ x = layer(x)
426
+
427
+ ## head
428
+ for layer in self.head:
429
+ if isinstance(layer, CausalConv1d) and feat_cache is not None:
430
+ idx = feat_idx[0]
431
+ cache_x = x[:, :, -CACHE_T:].clone()
432
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
433
+ cache_x = torch.cat(
434
+ [
435
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
436
+ cache_x,
437
+ ],
438
+ dim=2,
439
+ )
440
+ x = layer(x, feat_cache[idx])
441
+ feat_cache[idx] = cache_x
442
+ feat_idx[0] += 1
443
+ else:
444
+ x = layer(x)
445
+
446
+ return x
447
+
448
+
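With the default arguments above, the encoder's stage widths follow dims = [dim * u for u in [1] + dim_mult] and only the first len(temperal_downsample) stages may halve the time axis, so the default configuration compresses time by a factor of 4. A worked check of that bookkeeping (pure arithmetic, no model instantiation):

dim, dim_mult = 128, [1, 2, 4, 4]
dims = [dim * u for u in [1] + dim_mult]
print(dims)                                  # [128, 128, 256, 512, 512]
print(list(zip(dims[:-1], dims[1:])))        # the four Down_ResidualBlock stages
temperal_downsample = [True, True, False]
print(2 ** sum(temperal_downsample))         # 4: overall temporal compression
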
449
+ class Decoder1d(nn.Module):
450
+ def __init__(
451
+ self,
452
+ output_dim,
453
+ dim=128,
454
+ z_dim=4,
455
+ dim_mult=[1, 2, 4, 4],
456
+ num_res_blocks=2,
457
+ temperal_upsample=[False, True, True],
458
+ dropout=0.0,
459
+ ):
460
+ super().__init__()
461
+ self.dim = dim
462
+ self.z_dim = z_dim
463
+ self.dim_mult = dim_mult
464
+ self.num_res_blocks = num_res_blocks
465
+ self.temperal_upsample = temperal_upsample
466
+
467
+ # dimensions
468
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
469
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
470
+ # init block
471
+ self.conv1 = CausalConv1d(z_dim, dims[0], 3, padding=1)
472
+
473
+ # middle blocks
474
+ self.middle = nn.Sequential(
475
+ ResidualBlock(dims[0], dims[0], dropout),
476
+ RMS_norm(dims[0]),
477
+ CausalConv1d(dims[0], dims[0], 1),
478
+ ResidualBlock(dims[0], dims[0], dropout),
479
+ )
480
+
481
+ # upsample blocks
482
+ upsamples = []
483
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
484
+ t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
485
+ upsamples.append(
486
+ Up_ResidualBlock(
487
+ in_dim=in_dim,
488
+ out_dim=out_dim,
489
+ dropout=dropout,
490
+ mult=num_res_blocks + 1,
491
+ temperal_upsample=t_up_flag,
492
+ )
493
+ )
494
+ self.upsamples = nn.Sequential(*upsamples)
495
+
496
+ # output blocks
497
+ self.head = nn.Sequential(
498
+ RMS_norm(out_dim),
499
+ nn.SiLU(),
500
+ CausalConv1d(out_dim, output_dim, 3, padding=1),
501
+ )
502
+
503
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
504
+ if feat_cache is not None:
505
+ idx = feat_idx[0]
506
+ cache_x = x[:, :, -CACHE_T:].clone()
507
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
508
+ cache_x = torch.cat(
509
+ [
510
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
511
+ cache_x,
512
+ ],
513
+ dim=2,
514
+ )
515
+ x = self.conv1(x, feat_cache[idx])
516
+ feat_cache[idx] = cache_x
517
+ feat_idx[0] += 1
518
+ else:
519
+ x = self.conv1(x)
520
+
521
+ for layer in self.middle:
522
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
523
+ x = layer(x, feat_cache, feat_idx)
524
+ else:
525
+ x = layer(x)
526
+
527
+ ## upsamples
528
+ for layer in self.upsamples:
529
+ if feat_cache is not None:
530
+ x = layer(x, feat_cache, feat_idx, first_chunk)
531
+ else:
532
+ x = layer(x)
533
+
534
+ ## head
535
+ for layer in self.head:
536
+ if isinstance(layer, CausalConv1d) and feat_cache is not None:
537
+ idx = feat_idx[0]
538
+ cache_x = x[:, :, -CACHE_T:].clone()
539
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
540
+ cache_x = torch.cat(
541
+ [
542
+ feat_cache[idx][:, :, -1].unsqueeze(2).to(cache_x.device),
543
+ cache_x,
544
+ ],
545
+ dim=2,
546
+ )
547
+ x = layer(x, feat_cache[idx])
548
+ feat_cache[idx] = cache_x
549
+ feat_idx[0] += 1
550
+ else:
551
+ x = layer(x)
552
+ return x
553
+
554
+
555
+ def count_conv1d(model):
556
+ count = 0
557
+ for m in model.modules():
558
+ if isinstance(m, CausalConv1d):
559
+ count += 1
560
+ return count
561
+
562
+
563
+ class WanVAE_(nn.Module):
564
+ def __init__(
565
+ self,
566
+ input_dim,
567
+ dim=160,
568
+ dec_dim=256,
569
+ z_dim=16,
570
+ dim_mult=[1, 2, 4, 4],
571
+ num_res_blocks=1,
572
+ temperal_downsample=[True, True, False],
573
+ dropout=0.0,
574
+ ):
575
+ super().__init__()
576
+ self.dim = dim
577
+ self.z_dim = z_dim
578
+ self.dim_mult = dim_mult
579
+ self.num_res_blocks = num_res_blocks
580
+ self.temperal_downsample = temperal_downsample
581
+ self.temperal_upsample = temperal_downsample[::-1]
582
+
583
+ # modules
584
+ self.encoder = Encoder1d(
585
+ input_dim,
586
+ dim,
587
+ z_dim * 2,
588
+ dim_mult,
589
+ num_res_blocks,
590
+ self.temperal_downsample,
591
+ dropout,
592
+ )
593
+ self.conv1 = CausalConv1d(z_dim * 2, z_dim * 2, 1)
594
+ self.conv2 = CausalConv1d(z_dim, z_dim, 1)
595
+ self.decoder = Decoder1d(
596
+ input_dim,
597
+ dec_dim,
598
+ z_dim,
599
+ dim_mult,
600
+ num_res_blocks,
601
+ self.temperal_upsample,
602
+ dropout,
603
+ )
604
+
605
+ def forward(self, x, scale=[0, 1]):
606
+ mu = self.encode(x, scale)
607
+ x_recon = self.decode(mu, scale)
608
+ return x_recon, mu
609
+
610
+ def encode(self, x, scale, return_dist=False):
611
+ self.clear_cache()
612
+ t = x.shape[2]
613
+ iter_ = 1 + (t - 1) // 4
614
+ for i in range(iter_):
615
+ self._enc_conv_idx = [0]
616
+ if i == 0:
617
+ out = self.encoder(
618
+ x[:, :, :1],
619
+ feat_cache=self._enc_feat_map,
620
+ feat_idx=self._enc_conv_idx,
621
+ )
622
+ else:
623
+ out_ = self.encoder(
624
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
625
+ feat_cache=self._enc_feat_map,
626
+ feat_idx=self._enc_conv_idx,
627
+ )
628
+ out = torch.cat([out, out_], 2)
629
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
630
+ if isinstance(scale[0], torch.Tensor):
631
+ mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
632
+ 1, self.z_dim, 1
633
+ )
634
+ else:
635
+ mu = (mu - scale[0]) * scale[1]
636
+ self.clear_cache()
637
+ if return_dist:
638
+ return mu, log_var
639
+ return mu
640
+
641
+ def decode(self, z, scale):
642
+ self.clear_cache()
643
+ if isinstance(scale[0], torch.Tensor):
644
+ z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
645
+ else:
646
+ z = z / scale[1] + scale[0]
647
+ iter_ = z.shape[2]
648
+ x = self.conv2(z)
649
+ for i in range(iter_):
650
+ self._conv_idx = [0]
651
+ if i == 0:
652
+ out = self.decoder(
653
+ x[:, :, i : i + 1],
654
+ feat_cache=self._feat_map,
655
+ feat_idx=self._conv_idx,
656
+ first_chunk=True,
657
+ )
658
+ else:
659
+ out_ = self.decoder(
660
+ x[:, :, i : i + 1],
661
+ feat_cache=self._feat_map,
662
+ feat_idx=self._conv_idx,
663
+ )
664
+ out = torch.cat([out, out_], 2)
665
+ self.clear_cache()
666
+ return out
667
+
668
+ @torch.no_grad()
669
+ def stream_encode(self, x, first_chunk, scale, return_dist=False):
670
+ t = x.shape[2]
671
+ if first_chunk:
672
+ iter_ = 1 + (t - 1) // 4
673
+ else:
674
+ iter_ = t // 4
675
+ for i in range(iter_):
676
+ self._enc_conv_idx = [0]
677
+ if i == 0:
678
+ if first_chunk:
679
+ out = self.encoder(
680
+ x[:, :, :1],
681
+ feat_cache=self._enc_feat_map,
682
+ feat_idx=self._enc_conv_idx,
683
+ )
684
+ else:
685
+ out = self.encoder(
686
+ x[:, :, :4],
687
+ feat_cache=self._enc_feat_map,
688
+ feat_idx=self._enc_conv_idx,
689
+ )
690
+ else:
691
+ if first_chunk:
692
+ out_ = self.encoder(
693
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i],
694
+ feat_cache=self._enc_feat_map,
695
+ feat_idx=self._enc_conv_idx,
696
+ )
697
+ else:
698
+ out_ = self.encoder(
699
+ x[:, :, 4 * i : 4 * (i + 1)],
700
+ feat_cache=self._enc_feat_map,
701
+ feat_idx=self._enc_conv_idx,
702
+ )
703
+ out = torch.cat([out, out_], 2)
704
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
705
+ if isinstance(scale[0], torch.Tensor):
706
+ mu = (mu - scale[0].view(1, self.z_dim, 1)) * scale[1].view(
707
+ 1, self.z_dim, 1
708
+ )
709
+ else:
710
+ mu = (mu - scale[0]) * scale[1]
711
+ if return_dist:
712
+ return mu, log_var
713
+ else:
714
+ return mu
715
+
716
+ @torch.no_grad()
717
+ def stream_decode(self, z, first_chunk, scale):
718
+ if isinstance(scale[0], torch.Tensor):
719
+ z = z / scale[1].view(1, self.z_dim, 1) + scale[0].view(1, self.z_dim, 1)
720
+ else:
721
+ z = z / scale[1] + scale[0]
722
+ iter_ = z.shape[2]
723
+ x = self.conv2(z)
724
+ for i in range(iter_):
725
+ self._conv_idx = [0]
726
+ if i == 0:
727
+ out = self.decoder(
728
+ x[:, :, i : i + 1],
729
+ feat_cache=self._feat_map,
730
+ feat_idx=self._conv_idx,
731
+ first_chunk=first_chunk, # Use the external first_chunk parameter
732
+ )
733
+ else:
734
+ out_ = self.decoder(
735
+ x[:, :, i : i + 1],
736
+ feat_cache=self._feat_map,
737
+ feat_idx=self._conv_idx,
738
+ first_chunk=False, # Explicitly set to False for subsequent time steps within the same chunk
739
+ )
740
+ out = torch.cat([out, out_], 2)
741
+ return out
742
+
743
+ def reparameterize(self, mu, log_var):
744
+ std = torch.exp(0.5 * log_var)
745
+ eps = torch.randn_like(std)
746
+ return eps * std + mu
747
+
748
+ def sample(self, imgs, deterministic=False):
749
+ mu, log_var = self.encode(imgs)
750
+ if deterministic:
751
+ return mu
752
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
753
+ return mu + std * torch.randn_like(std)
754
+
755
+ def clear_cache(self):
756
+ self._conv_num = count_conv1d(self.decoder)
757
+ self._conv_idx = [0]
758
+ self._feat_map = [None] * self._conv_num
759
+ # cache encode
760
+ self._enc_conv_num = count_conv1d(self.encoder)
761
+ self._enc_conv_idx = [0]
762
+ self._enc_feat_map = [None] * self._enc_conv_num
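encode() above walks the sequence causally: the first chunk is a single frame and every later chunk covers 4 frames, so an input of 1 + 4k frames yields 1 + k latent steps; decode() reverses this one latent at a time. A usage sketch with the default constructor arguments (the 85-dim feature size and the import path are assumptions for illustration):

import torch
from ldf_models.tools.wan_vae_1d import WanVAE_

vae = WanVAE_(input_dim=85).eval()
x = torch.randn(1, 85, 33)                  # (B, C, T) with T = 1 + 4 * 8
with torch.no_grad():
    mu = vae.encode(x, scale=[0, 1])        # expected (1, 16, 9): 1 + 8 latent steps
    recon = vae.decode(mu, scale=[0, 1])    # expected (1, 85, 33): frames recovered
print(mu.shape, recon.shape)
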
ldf_models/vae_wan_1d.py ADDED
@@ -0,0 +1,212 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .tools.wan_vae_1d import WanVAE_
6
+
7
+
8
+ class VAEWanModel(nn.Module):
9
+ def __init__(
10
+ self,
11
+ input_dim,
12
+ mean_path=None,
13
+ std_path=None,
14
+ z_dim=256,
15
+ dim=160,
16
+ dec_dim=512,
17
+ num_res_blocks=1,
18
+ dropout=0.0,
19
+ dim_mult=[1, 1, 1],
20
+ temperal_downsample=[True, True],
21
+ vel_window=[0, 0],
22
+ **kwargs,
23
+ ):
24
+ super().__init__()
25
+
26
+ self.mean_path = mean_path
27
+ self.std_path = std_path
28
+ self.input_dim = input_dim
29
+ self.z_dim = z_dim
30
+ self.dim = dim
31
+ self.dec_dim = dec_dim
32
+ self.num_res_blocks = num_res_blocks
33
+ self.dropout = dropout
34
+ self.dim_mult = dim_mult
35
+ self.temperal_downsample = temperal_downsample
36
+ self.vel_window = vel_window
37
+ self.RECONS_LOSS = nn.SmoothL1Loss()
38
+ self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
39
+ self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
40
+ self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)
41
+
42
+ if self.mean_path is not None:
43
+ self.register_buffer(
44
+ "mean", torch.from_numpy(np.load(self.mean_path)).float()
45
+ )
46
+ else:
47
+ self.register_buffer("mean", torch.zeros(input_dim))
48
+
49
+ if self.std_path is not None:
50
+ self.register_buffer(
51
+ "std", torch.from_numpy(np.load(self.std_path)).float()
52
+ )
53
+ else:
54
+ self.register_buffer("std", torch.ones(input_dim))
55
+
56
+ self.model = WanVAE_(
57
+ input_dim=self.input_dim,
58
+ dim=self.dim,
59
+ dec_dim=self.dec_dim,
60
+ z_dim=self.z_dim,
61
+ dim_mult=self.dim_mult,
62
+ num_res_blocks=self.num_res_blocks,
63
+ temperal_downsample=self.temperal_downsample,
64
+ dropout=self.dropout,
65
+ )
66
+
67
+ downsample_factor = 1
68
+ for flag in self.temperal_downsample:
69
+ if flag:
70
+ downsample_factor *= 2
71
+ self.downsample_factor = downsample_factor
72
+
73
+ def preprocess(self, x):
74
+ # (bs, T, C) -> (bs, C, T)
75
+ x = x.permute(0, 2, 1)
76
+ return x
77
+
78
+ def postprocess(self, x):
79
+ # (bs, C, T) -> (bs, T, C)
80
+ x = x.permute(0, 2, 1)
81
+ return x
82
+
83
+ def forward(self, x):
84
+ features = x["feature"]
85
+ feature_length = x["feature_length"]
86
+ features = (features - self.mean) / self.std
87
+ # create mask based on feature_length
88
+ batch_size, seq_len = features.shape[:2]
89
+ mask = torch.zeros(
90
+ batch_size, seq_len, dtype=torch.bool, device=features.device
91
+ )
92
+ for i in range(batch_size):
93
+ mask[i, : feature_length[i]] = True
94
+
95
+ x_in = self.preprocess(features) # (bs, input_dim, T)
96
+ mu, log_var = self.model.encode(
97
+ x_in, scale=[0, 1], return_dist=True
98
+ ) # (bs, z_dim, T)
99
+ z = self.model.reparameterize(mu, log_var)
100
+ x_decoder = self.model.decode(z, scale=[0, 1]) # (bs, input_dim, T)
101
+ x_out = self.postprocess(x_decoder) # (bs, T, input_dim)
102
+
103
+ if x_out.size(1) != features.size(1):
104
+ min_len = min(x_out.size(1), features.size(1))
105
+ x_out = x_out[:, :min_len, :]
106
+ features = features[:, :min_len, :]
107
+ mask = mask[:, :min_len]
108
+
109
+ mask_expanded = mask.unsqueeze(-1)
110
+ x_out_masked = x_out * mask_expanded
111
+ features_masked = features * mask_expanded
112
+ loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)
113
+ vel_start = self.vel_window[0]
114
+ vel_end = self.vel_window[1]
115
+ loss_vel = self.RECONS_LOSS(
116
+ x_out_masked[..., vel_start:vel_end],
117
+ features_masked[..., vel_start:vel_end],
118
+ )
119
+
120
+ # Compute KL divergence loss
121
+ # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
122
+ # log_var = log(sigma^2), so we can use it directly
123
+
124
+ # Build mask for latent space
125
+ T_latent = mu.size(2)
126
+ mask_downsampled = torch.zeros(
127
+ batch_size, T_latent, dtype=torch.bool, device=features.device
128
+ )
129
+ for i in range(batch_size):
130
+ latent_length = (
131
+ feature_length[i] + self.downsample_factor - 1
132
+ ) // self.downsample_factor
133
+ mask_downsampled[i, :latent_length] = True
134
+ mask_latent = mask_downsampled.unsqueeze(1) # (B, 1, T_latent)
135
+
136
+ # Compute KL loss per element
137
+ kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
138
+ # Apply mask: only compute KL loss for valid timesteps
139
+ kl_masked = kl_per_element * mask_latent
140
+ # Sum over all dimensions and normalize by the number of valid elements
141
+ kl_loss = torch.sum(kl_masked) / (
142
+ torch.sum(mask_downsampled) * mu.size(1)
143
+ ) # normalize by valid timesteps * latent_dim
144
+
145
+ # Total loss
146
+ total_loss = (
147
+ self.LAMBDA_FEATURE * loss_recons
148
+ + self.LAMBDA_VELOCITY * loss_vel
149
+ + self.LAMBDA_KL * kl_loss
150
+ )
151
+
152
+ loss_dict = {}
153
+ loss_dict["total"] = total_loss
154
+ loss_dict["recons"] = loss_recons
155
+ loss_dict["velocity"] = loss_vel
156
+ loss_dict["kl"] = kl_loss
157
+
158
+ return loss_dict
159
+
160
+ def encode(self, x):
161
+ x = (x - self.mean) / self.std
162
+ x_in = self.preprocess(x) # (bs, T, input_dim) -> (bs, input_dim, T)
163
+ mu = self.model.encode(x_in, scale=[0, 1]) # (bs, z_dim, T)
164
+ mu = self.postprocess(mu) # (bs, T, z_dim)
165
+ return mu
166
+
167
+ def decode(self, mu):
168
+ mu_in = self.preprocess(mu) # (bs, T, z_dim) -> (bs, z_dim, T)
169
+ x_decoder = self.model.decode(mu_in, scale=[0, 1])  # (bs, input_dim, T)
170
+ x_out = self.postprocess(x_decoder) # (bs, T, input_dim)
171
+ x_out = x_out * self.std + self.mean
172
+ return x_out
173
+
174
+ @torch.no_grad()
175
+ def stream_encode(self, x, first_chunk=True):
176
+ x = (x - self.mean) / self.std
177
+ x_in = self.preprocess(x) # (bs, input_dim, T)
178
+ mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
179
+ mu = self.postprocess(mu) # (bs, T, z_dim)
180
+ return mu
181
+
182
+ @torch.no_grad()
183
+ def stream_decode(self, mu, first_chunk=True):
184
+ mu_in = self.preprocess(mu) # (bs, z_dim, T)
185
+ x_decoder = self.model.stream_decode(
186
+ mu_in, first_chunk=first_chunk, scale=[0, 1]
187
+ )
188
+ x_out = self.postprocess(x_decoder) # (bs, T, input_dim)
189
+ x_out = x_out * self.std + self.mean
190
+ return x_out
191
+
192
+ def clear_cache(self):
193
+ self.model.clear_cache()
194
+
195
+ def generate(self, x):
196
+ features = x["feature"]
197
+ feature_length = x["feature_length"]
198
+ y_hat = self.decode(self.encode(features))
199
+
200
+ y_hat_out = []
201
+
202
+ for i in range(y_hat.shape[0]):
203
+ # cut off the padding and align lengths
204
+ valid_len = (
205
+ feature_length[i] - 1
206
+ ) // self.downsample_factor * self.downsample_factor + 1
207
+ # Make sure both have the same length (take minimum)
208
+ y_hat_out.append(y_hat[i, :valid_len, :])
209
+
210
+ out = {}
211
+ out["generated"] = y_hat_out
212
+ return out
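The wrapper above works on (B, T, C) motion features, applies mean/std normalization internally, and compresses time by self.downsample_factor (4 with the default temperal_downsample=[True, True]). A usage sketch; the 263-dim HumanML3D-style feature size is an illustrative choice:

import torch
from ldf_models.vae_wan_1d import VAEWanModel

model = VAEWanModel(input_dim=263).eval()
motion = torch.randn(1, 49, 263)            # (B, T, C) with T = 1 + 4 * 12
with torch.no_grad():
    latents = model.encode(motion)          # expected (1, 13, 256): T compressed by 4
    recon = model.decode(latents)           # expected (1, 49, 263)
print(model.downsample_factor, latents.shape, recon.shape)
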
ldf_utils/__init__.py ADDED
File without changes
ldf_utils/initialize.py ADDED
@@ -0,0 +1,286 @@
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import time
5
+ from datetime import datetime
6
+ from importlib import import_module
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Optional
9
+
10
+ import torch
11
+ from lightning.pytorch.utilities import rank_zero_info
12
+ from omegaconf import OmegaConf
13
+
14
+
15
+ class Config:
16
+ def __init__(self, config_path: str = None, override_args: Dict[str, Any] = None):
17
+ self.config = OmegaConf.create({})
18
+
19
+ # Load main config if provided
20
+ if config_path:
21
+ self.load_yaml(config_path)
22
+ if override_args:
23
+ self.override_config(override_args)
24
+
25
+ def load_yaml(self, config_path: str):
26
+ """Load YAML configuration file"""
27
+ loaded_config = OmegaConf.load(config_path)
28
+ self.config = OmegaConf.merge(self.config, loaded_config)
29
+
30
+ def override_config(self, override_args: Dict[str, Any]):
31
+ """Handle command line override arguments"""
32
+ dotlist = []
33
+ for key, value in override_args.items():
+ # Convert the raw override string to a basic Python type; plain strings
+ # (e.g. paths ending in .yaml) pass through _convert_value unchanged.
+ val = self._convert_value(value)
+ # OmegaConf.update accepts dotted keys, so nested values can be overridden directly.
49
+ OmegaConf.update(self.config, key, val)
50
+
51
+ def _convert_value(self, value: str) -> Any:
52
+ """Convert string value to appropriate type"""
53
+ if value.lower() == "true":
54
+ return True
55
+ elif value.lower() == "false":
56
+ return False
57
+ elif value.lower() == "null":
58
+ return None
59
+ try:
60
+ return int(value)
61
+ except ValueError:
62
+ try:
63
+ return float(value)
64
+ except ValueError:
65
+ return value
66
+
67
+ def get(self, key: str, default: Any = None) -> Any:
68
+ """Get configuration value"""
69
+ return OmegaConf.select(self.config, key, default=default)
70
+
71
+ def __getattr__(self, name: str) -> Any:
72
+ """Support dot notation access"""
73
+ return self.config[name]
74
+
75
+ def __getitem__(self, key: str) -> Any:
76
+ """Support dictionary-like access"""
77
+ return self.config[key]
78
+
79
+ def export_config(self, path: str):
80
+ """Export current configuration to file"""
81
+ OmegaConf.save(self.config, path)
82
+
83
+
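A usage sketch of the Config helper above; the YAML file name and keys are made up for illustration:

from ldf_utils.initialize import Config

# config.yaml (hypothetical):
#   exp_name: debug
#   train:
#     lr: 0.0001
cfg = Config("config.yaml", override_args={"train.lr": "0.001", "train.use_ema": "true"})
print(cfg.exp_name)                  # dot-notation access -> "debug"
print(cfg.get("train.lr"))           # overridden and type-converted -> 0.001
print(cfg["train"]["use_ema"])       # "true" parsed to a boolean -> True
cfg.export_config("resolved.yaml")   # write the merged configuration back out
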
84
+ def parse_args():
85
+ """Parse command line arguments"""
86
+ parser = argparse.ArgumentParser()
87
+ parser.add_argument(
88
+ "--config", type=str, required=True, help="Path to config file"
89
+ )
90
+ parser.add_argument(
91
+ "--override", type=str, nargs="+", help="Override config values (key=value)"
92
+ )
93
+ return parser.parse_args()
94
+
95
+
96
+ def load_config(
97
+ config_path: Optional[str] = None, override_args: Optional[Dict[str, Any]] = None
98
+ ) -> Config:
99
+ """Load configuration"""
100
+ if config_path is None:
101
+ args = parse_args()
102
+ config_path = args.config
103
+ if args.override:
104
+ override_args = {}
105
+ for override in args.override:
106
+ key, value = override.split("=", 1)
107
+ override_args[key.strip()] = value.strip()
108
+
109
+ return Config(config_path, override_args)
110
+
111
+
112
+ def instantiate(target, cfg=None, hfstyle=False, **init_args):
113
+ module_name, class_name = target.rsplit(".", 1)
114
+ module = import_module(module_name)
115
+ class_ = getattr(module, class_name)
116
+ if cfg is None:
117
+ return class_(**init_args)
118
+ else:
119
+ if hfstyle:
120
+ config_class = class_.config_class
121
+ cfg = config_class(config_obj=cfg)
122
+ return class_(cfg, **init_args)
123
+
124
+
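instantiate() builds an object from a dotted import path, which is how configs can name the classes they want constructed; get_function() does the same for plain callables. A hedged sketch using standard-library targets rather than the repo's own classes:

from ldf_utils.initialize import get_function, instantiate

ordered = instantiate("collections.OrderedDict", a=1, b=2)   # class resolved from its dotted path
print(dict(ordered))                                          # {'a': 1, 'b': 2}

join = get_function("os.path.join")                           # callables resolve the same way
print(join("outputs", "run_001"))                             # outputs/run_001
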
125
+ def get_function(target):
126
+ module_name, function_name = target.rsplit(".", 1)
127
+ module = import_module(module_name)
128
+ function_ = getattr(module, function_name)
129
+ return function_
130
+
131
+
132
+ def save_config_and_codes(config, save_dir):
133
+ os.makedirs(save_dir, exist_ok=True)
134
+ sanity_check_dir = os.path.join(save_dir, "sanity_check")
135
+ os.makedirs(sanity_check_dir, exist_ok=True)
136
+ with open(os.path.join(sanity_check_dir, f"{config.exp_name}.yaml"), "w") as f:
137
+ OmegaConf.save(config.config, f)
138
+ current_dir = Path.cwd()
139
+ exclude_dir = current_dir / "outputs"
140
+ for py_file in current_dir.rglob("*.py"):
141
+ if exclude_dir in py_file.parents:
142
+ continue
143
+ dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
144
+ dest_path.parent.mkdir(parents=True, exist_ok=True)
145
+ shutil.copy(py_file, dest_path)
146
+
147
+
148
+ def print_model_size(model):
149
+ total_params = sum(p.numel() for p in model.parameters())
150
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
151
+ rank_zero_info(f"Total parameters: {total_params:,}")
152
+ rank_zero_info(f"Trainable parameters: {trainable_params:,}")
153
+ rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}")
154
+
155
+
156
+ def compare_statedict_and_parameters(state_dict, named_parameters, named_buffers):
157
+ """Compare differences between state_dict and parameters"""
158
+ # Get all keys in state_dict
159
+ state_dict_keys = set(state_dict.keys())
160
+
161
+ # Get all keys in named_parameters
162
+ named_params_keys = set(name for name, _ in named_parameters)
163
+
164
+ # Find keys that only exist in state_dict
165
+ only_in_state_dict = state_dict_keys - named_params_keys
166
+
167
+ # Find keys that only exist in named_parameters
168
+ only_in_named_params = named_params_keys - state_dict_keys
169
+
170
+ # Print results
171
+ if only_in_state_dict:
172
+ print(f"Only in state_dict (not in parameters): {sorted(only_in_state_dict)}")
173
+
174
+ if only_in_named_params:
175
+ print(
176
+ f"Only in named_parameters (not in state_dict): {sorted(only_in_named_params)}"
177
+ )
178
+
179
+ if not only_in_state_dict and not only_in_named_params:
180
+ print("All parameters match between state_dict and named_parameters")
181
+
182
+ # Additionally compare buffers (non-parameter states, such as BatchNorm's running_mean)
183
+ named_buffers_keys = set(name for name, _ in named_buffers)
184
+ buffers_only = state_dict_keys - named_params_keys - named_buffers_keys
185
+
186
+ if buffers_only:
187
+ print(
188
+ f"Other items in state_dict (neither params nor buffers): {sorted(buffers_only)}"
189
+ )
190
+
191
+ print(f"Total state_dict items: {len(state_dict_keys)}")
192
+ print(f"Total named_parameters: {len(named_params_keys)}")
193
+ print(f"Total named_buffers: {len(named_buffers_keys)}")
194
+
195
+
196
+ def _resolve_global_rank() -> int:
197
+ """Resolve the global rank from environment variables."""
198
+ for key in ("GLOBAL_RANK", "RANK", "SLURM_PROCID", "LOCAL_RANK"):
199
+ if key in os.environ:
200
+ try:
201
+ return int(os.environ[key])
202
+ except ValueError:
203
+ continue
204
+ return 0
205
+
206
+
207
+ def get_shared_run_time(base_dir: str, env_key: str = "PL_RUN_TIME") -> str:
208
+ """
209
+ Get a synchronized run time across all processes.
210
+
211
+ This function ensures all processes (both in distributed training and multi-process
212
+ scenarios) use the same timestamp for output directories and experiment tracking.
213
+
214
+ Args:
215
+ base_dir: Base directory for output files
216
+ env_key: Environment variable key to cache the run time
217
+
218
+ Returns:
219
+ Synchronized timestamp string in format YYYYMMDD_HHMMSS
220
+ """
221
+ cached = os.environ.get(env_key)
222
+ if cached:
223
+ return cached
224
+
225
+ timestamp_format = "%Y%m%d_%H%M%S"
226
+
227
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
228
+ if torch.distributed.get_rank() == 0:
229
+ run_time = datetime.now().strftime(timestamp_format)
230
+ else:
231
+ run_time = None
232
+ container = [run_time]
233
+ torch.distributed.broadcast_object_list(container, src=0)
234
+ run_time = container[0]
235
+ if run_time is None:
236
+ raise RuntimeError("Failed to synchronize run time across ranks.")
237
+ os.environ[env_key] = run_time
238
+ return run_time
239
+
240
+ os.makedirs(base_dir, exist_ok=True)
241
+ sync_token = (
242
+ os.environ.get("SLURM_JOB_ID")
243
+ or os.environ.get("TORCHELASTIC_RUN_ID")
244
+ or os.environ.get("JOB_ID")
245
+ or "default"
246
+ )
247
+ sync_dir = os.path.join(base_dir, ".run_time_sync")
248
+ os.makedirs(sync_dir, exist_ok=True)
249
+ sync_file = os.path.join(sync_dir, f"{sync_token}.txt")
250
+
251
+ global_rank = _resolve_global_rank()
252
+ if global_rank == 0:
253
+ # Remove the sync file if it exists to avoid stale reads by other ranks
254
+ if os.path.exists(sync_file):
255
+ try:
256
+ os.remove(sync_file)
257
+ except OSError:
258
+ pass
259
+
260
+ run_time = datetime.now().strftime(timestamp_format)
261
+ with open(sync_file, "w", encoding="utf-8") as f:
262
+ f.write(run_time)
263
+ else:
264
+ timeout = time.monotonic() + 1200.0
265
+ while True:
266
+ if os.path.exists(sync_file):
267
+ try:
268
+ with open(sync_file, "r", encoding="utf-8") as f:
269
+ run_time = f.read().strip()
270
+ # Check if the timestamp is fresh (within 60 seconds)
271
+ # This prevents reading a stale timestamp from a previous run
272
+ dt = datetime.strptime(run_time, timestamp_format)
273
+ if abs((datetime.now() - dt).total_seconds()) < 60:
274
+ break
275
+ except (ValueError, OSError):
276
+ # File might be empty or partially written, or format mismatch
277
+ pass
278
+
279
+ if time.monotonic() > timeout:
280
+ raise TimeoutError(
281
+ "Timed out waiting for rank 0 to write synchronized timestamp."
282
+ )
283
+ time.sleep(0.1)
284
+
285
+ os.environ[env_key] = run_time
286
+ return run_time
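A usage sketch of the synchronization above: every process calls the function with the same base_dir; rank 0 produces the timestamp (broadcast when torch.distributed is initialized, written to a sync file otherwise) and the other ranks pick it up, so all ranks share one run directory. The directory names are illustrative:

import os
from ldf_utils.initialize import get_shared_run_time

run_time = get_shared_run_time("outputs")             # e.g. "20240101_120000", identical on every rank
run_dir = os.path.join("outputs", f"run_{run_time}")
os.makedirs(run_dir, exist_ok=True)
# later calls in the same process are cheap: the value is cached in the PL_RUN_TIME env var
assert get_shared_run_time("outputs") == run_time
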
ldf_utils/math/__init__.py ADDED
File without changes
ldf_utils/math/quaternion.py ADDED
@@ -0,0 +1,447 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ _EPS4 = np.finfo(float).eps * 4.0
12
+
13
+ _FLOAT_EPS = np.finfo(np.float64).eps
14
+
15
+ # PyTorch-backed implementations
16
+
17
+
18
+ def qinv(q):
19
+ assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
20
+ mask = torch.ones_like(q)
21
+ mask[..., 1:] = -mask[..., 1:]
22
+ return q * mask
23
+
24
+
25
+ def qinv_np(q):
26
+ assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
27
+ return qinv(torch.from_numpy(q).float()).numpy()
28
+
29
+
30
+ def qnormalize(q):
31
+ assert q.shape[-1] == 4, "q must be a tensor of shape (*, 4)"
32
+ return q / torch.norm(q, dim=-1, keepdim=True)
33
+
34
+
35
+ def qmul(q, r):
36
+ """
37
+ Multiply quaternion(s) q with quaternion(s) r.
38
+ Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
39
+ Returns q*r as a tensor of shape (*, 4).
40
+ """
41
+ assert q.shape[-1] == 4
42
+ assert r.shape[-1] == 4
43
+
44
+ original_shape = q.shape
45
+
46
+ # Compute outer product
47
+ terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
48
+
49
+ w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
50
+ x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
51
+ y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
52
+ z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
53
+ return torch.stack((w, x, y, z), dim=1).view(original_shape)
54
+
55
+
56
+ def qrot(q, v):
57
+ """
58
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
59
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
60
+ where * denotes any number of dimensions.
61
+ Returns a tensor of shape (*, 3).
62
+ """
63
+ assert q.shape[-1] == 4
64
+ assert v.shape[-1] == 3
65
+ assert q.shape[:-1] == v.shape[:-1]
66
+
67
+ original_shape = list(v.shape)
68
+ # print(q.shape)
69
+ q = q.contiguous().view(-1, 4)
70
+ v = v.contiguous().view(-1, 3)
71
+
72
+ qvec = q[:, 1:]
73
+ uv = torch.cross(qvec, v, dim=1)
74
+ uuv = torch.cross(qvec, uv, dim=1)
75
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
76
+
77
+
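A quick sanity check of the helpers above: with the (w, x, y, z) convention, a quaternion for a 90-degree rotation about the Y axis should map the X axis onto -Z under qrot, and qmul of a quaternion with its qinv should give the identity. A small sketch:

import numpy as np
import torch
from ldf_utils.math.quaternion import qinv, qmul, qrot

q = torch.tensor([[np.cos(np.pi / 4), 0.0, np.sin(np.pi / 4), 0.0]], dtype=torch.float32)
v = torch.tensor([[1.0, 0.0, 0.0]])
print(qrot(q, v))          # ~[[0, 0, -1]]: X axis rotated onto -Z
print(qmul(q, qinv(q)))    # ~[[1, 0, 0, 0]]: identity quaternion
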
78
+ def qeuler(q, order, epsilon=0, deg=True):
79
+ """
80
+ Convert quaternion(s) q to Euler angles.
81
+ Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
82
+ Returns a tensor of shape (*, 3).
83
+ """
84
+ assert q.shape[-1] == 4
85
+
86
+ original_shape = list(q.shape)
87
+ original_shape[-1] = 3
88
+ q = q.view(-1, 4)
89
+
90
+ q0 = q[:, 0]
91
+ q1 = q[:, 1]
92
+ q2 = q[:, 2]
93
+ q3 = q[:, 3]
94
+
95
+ if order == "xyz":
96
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
97
+ y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
98
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
99
+ elif order == "yzx":
100
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
101
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
102
+ z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
103
+ elif order == "zxy":
104
+ x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
105
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
106
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
107
+ elif order == "xzy":
108
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
109
+ y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
110
+ z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
111
+ elif order == "yxz":
112
+ x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
113
+ y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
114
+ z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
115
+ elif order == "zyx":
116
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
117
+ y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
118
+ z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
119
+ else:
120
+ raise ValueError(f"Invalid rotation order: {order}")
121
+
122
+ if deg:
123
+ return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
124
+ else:
125
+ return torch.stack((x, y, z), dim=1).view(original_shape)
126
+
127
+
128
+ # Numpy-backed implementations
129
+
130
+
131
+ def qmul_np(q, r):
132
+ q = torch.from_numpy(q).contiguous().float()
133
+ r = torch.from_numpy(r).contiguous().float()
134
+ return qmul(q, r).numpy()
135
+
136
+
137
+ def qrot_np(q, v):
138
+ q = torch.from_numpy(q).contiguous().float()
139
+ v = torch.from_numpy(v).contiguous().float()
140
+ return qrot(q, v).numpy()
141
+
142
+
143
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
144
+ if use_gpu:
145
+ q = torch.from_numpy(q).cuda().float()
146
+ return qeuler(q, order, epsilon).cpu().numpy()
147
+ else:
148
+ q = torch.from_numpy(q).contiguous().float()
149
+ return qeuler(q, order, epsilon).numpy()
150
+
151
+
152
+ def qfix(q):
153
+ """
154
+ Enforce quaternion continuity across the time dimension by selecting
155
+ the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
156
+ between two consecutive frames.
157
+
158
+ Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
159
+ Returns a tensor of the same shape.
160
+ """
161
+ assert len(q.shape) == 3
162
+ assert q.shape[-1] == 4
163
+
164
+ result = q.copy()
165
+ dot_products = np.sum(q[1:] * q[:-1], axis=2)
166
+ mask = dot_products < 0
167
+ mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
168
+ result[1:][mask] *= -1
169
+ return result
170
+
171
+
172
+ def euler2quat(e, order, deg=True):
173
+ """
174
+ Convert Euler angles to quaternions.
175
+ """
176
+ assert e.shape[-1] == 3
177
+
178
+ original_shape = list(e.shape)
179
+ original_shape[-1] = 4
180
+
181
+ e = e.view(-1, 3)
182
+
183
+ # if euler angles in degrees
184
+ if deg:
185
+ e = e * np.pi / 180.0
186
+
187
+ x = e[:, 0]
188
+ y = e[:, 1]
189
+ z = e[:, 2]
190
+
191
+ rx = torch.stack(
192
+ (torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)),
193
+ dim=1,
194
+ )
195
+ ry = torch.stack(
196
+ (torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)),
197
+ dim=1,
198
+ )
199
+ rz = torch.stack(
200
+ (torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)),
201
+ dim=1,
202
+ )
203
+
204
+ result = None
205
+ for coord in order:
206
+ if coord == "x":
207
+ r = rx
208
+ elif coord == "y":
209
+ r = ry
210
+ elif coord == "z":
211
+ r = rz
212
+ else:
213
+ raise ValueError(f"Invalid axis in rotation order: {coord}")
214
+ if result is None:
215
+ result = r
216
+ else:
217
+ result = qmul(result, r)
218
+
219
+ # Reverse antipodal representation to have a non-negative "w"
220
+ if order in ["xyz", "yzx", "zxy"]:
221
+ result *= -1
222
+
223
+ return result.view(original_shape)
224
+
225
+
226
+ def expmap_to_quaternion(e):
227
+ """
228
+ Convert axis-angle rotations (aka exponential maps) to quaternions.
229
+ Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
230
+ Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
231
+ Returns a tensor of shape (*, 4).
232
+ """
233
+ assert e.shape[-1] == 3
234
+
235
+ original_shape = list(e.shape)
236
+ original_shape[-1] = 4
237
+ e = e.reshape(-1, 3)
238
+
239
+ theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
240
+ w = np.cos(0.5 * theta).reshape(-1, 1)
241
+ xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
242
+ return np.concatenate((w, xyz), axis=1).reshape(original_shape)
243
+
244
+
245
+ def euler_to_quaternion(e, order):
246
+ """
247
+ Convert Euler angles to quaternions.
248
+ """
249
+ assert e.shape[-1] == 3
250
+
251
+ original_shape = list(e.shape)
252
+ original_shape[-1] = 4
253
+
254
+ e = e.reshape(-1, 3)
255
+
256
+ x = e[:, 0]
257
+ y = e[:, 1]
258
+ z = e[:, 2]
259
+
260
+ rx = np.stack(
261
+ (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1
262
+ )
263
+ ry = np.stack(
264
+ (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1
265
+ )
266
+ rz = np.stack(
267
+ (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1
268
+ )
269
+
270
+ result = None
271
+ for coord in order:
272
+ if coord == "x":
273
+ r = rx
274
+ elif coord == "y":
275
+ r = ry
276
+ elif coord == "z":
277
+ r = rz
278
+ else:
279
+ raise ValueError(f"Invalid axis in rotation order: {coord}")
280
+ if result is None:
281
+ result = r
282
+ else:
283
+ result = qmul_np(result, r)
284
+
285
+ # Reverse antipodal representation to have a non-negative "w"
286
+ if order in ["xyz", "yzx", "zxy"]:
287
+ result *= -1
288
+
289
+ return result.reshape(original_shape)
290
+
291
+
292
+ def quaternion_to_matrix(quaternions):
293
+ """
294
+ Convert rotations given as quaternions to rotation matrices.
295
+ Args:
296
+ quaternions: quaternions with real part first,
297
+ as tensor of shape (..., 4).
298
+ Returns:
299
+ Rotation matrices as tensor of shape (..., 3, 3).
300
+ """
301
+ r, i, j, k = torch.unbind(quaternions, -1)
302
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
303
+
304
+ o = torch.stack(
305
+ (
306
+ 1 - two_s * (j * j + k * k),
307
+ two_s * (i * j - k * r),
308
+ two_s * (i * k + j * r),
309
+ two_s * (i * j + k * r),
310
+ 1 - two_s * (i * i + k * k),
311
+ two_s * (j * k - i * r),
312
+ two_s * (i * k - j * r),
313
+ two_s * (j * k + i * r),
314
+ 1 - two_s * (i * i + j * j),
315
+ ),
316
+ -1,
317
+ )
318
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
319
+
320
+
321
+ def quaternion_to_matrix_np(quaternions):
322
+ q = torch.from_numpy(quaternions).contiguous().float()
323
+ return quaternion_to_matrix(q).numpy()
324
+
325
+
326
+ def quaternion_to_cont6d_np(quaternions):
327
+ rotation_mat = quaternion_to_matrix_np(quaternions)
328
+ cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
329
+ return cont_6d
330
+
331
+
332
+ def quaternion_to_cont6d(quaternions):
333
+ rotation_mat = quaternion_to_matrix(quaternions)
334
+ cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
335
+ return cont_6d
336
+
337
+
338
+ def cont6d_to_matrix(cont6d):
339
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
340
+ x_raw = cont6d[..., 0:3]
341
+ y_raw = cont6d[..., 3:6]
342
+
343
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
344
+ z = torch.cross(x, y_raw, dim=-1)
345
+ z = z / torch.norm(z, dim=-1, keepdim=True)
346
+
347
+ y = torch.cross(z, x, dim=-1)
348
+
349
+ x = x[..., None]
350
+ y = y[..., None]
351
+ z = z[..., None]
352
+
353
+ mat = torch.cat([x, y, z], dim=-1)
354
+ return mat
355
+
356
+
357
+ def cont6d_to_matrix_np(cont6d):
358
+ q = torch.from_numpy(cont6d).contiguous().float()
359
+ return cont6d_to_matrix(q).numpy()
360
+
361
+
362
+ def qpow(q0, t, dtype=torch.float):
363
+ """q0 : tensor of quaternions
364
+ t: tensor of powers
365
+ """
366
+ q0 = qnormalize(q0)
367
+ theta0 = torch.acos(q0[..., 0])
368
+
369
+ # if theta0 is close to zero, add epsilon to avoid NaNs
370
+ mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
371
+ theta0 = (1 - mask) * theta0 + mask * 10e-10
372
+ v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
373
+
374
+ if isinstance(t, torch.Tensor):
375
+ q = torch.zeros(t.shape + q0.shape)
376
+ theta = t.view(-1, 1) * theta0.view(1, -1)
377
+ else: # if t is a number
378
+ q = torch.zeros(q0.shape)
379
+ theta = t * theta0
380
+
381
+ q[..., 0] = torch.cos(theta)
382
+ q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
383
+
384
+ return q.to(dtype)
385
+
386
+
387
+ def qslerp(q0, q1, t):
388
+ """
389
+ q0: starting quaternion
390
+ q1: ending quaternion
391
+ t: array of points along the way
392
+
393
+ Returns:
394
+ Tensor of Slerps: t.shape + q0.shape
395
+ """
396
+
397
+ q0 = qnormalize(q0)
398
+ q1 = qnormalize(q1)
399
+ q_ = qpow(qmul(q1, qinv(q0)), t)
400
+
401
+ return qmul(
402
+ q_,
403
+ q0.contiguous()
404
+ .view(torch.Size([1] * len(t.shape)) + q0.shape)
405
+ .expand(t.shape + q0.shape)
406
+ .contiguous(),
407
+ )
408
+
409
+
410
+ def qbetween(v0, v1):
411
+ """
412
+ find the quaternion used to rotate v0 to v1
413
+ """
414
+ assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
415
+ assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"
416
+
417
+ v = torch.cross(v0, v1)
418
+ w = torch.sqrt(
419
+ (v0**2).sum(dim=-1, keepdim=True) * (v1**2).sum(dim=-1, keepdim=True)
420
+ ) + (v0 * v1).sum(dim=-1, keepdim=True)
421
+ return qnormalize(torch.cat([w, v], dim=-1))
422
+
423
+
424
+ def qbetween_np(v0, v1):
425
+ """
426
+ find the quaternion used to rotate v0 to v1
427
+ """
428
+ assert v0.shape[-1] == 3, "v0 must be of the shape (*, 3)"
429
+ assert v1.shape[-1] == 3, "v1 must be of the shape (*, 3)"
430
+
431
+ v0 = torch.from_numpy(v0).float()
432
+ v1 = torch.from_numpy(v1).float()
433
+ return qbetween(v0, v1).numpy()
434
+
435
+
436
+ def lerp(p0, p1, t):
437
+ if not isinstance(t, torch.Tensor):
438
+ t = torch.Tensor([t])
439
+
440
+ new_shape = t.shape + p0.shape
441
+ new_view_t = t.shape + torch.Size([1] * len(p0.shape))
442
+ new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
443
+ p0 = p0.view(new_view_p).expand(new_shape)
444
+ p1 = p1.view(new_view_p).expand(new_shape)
445
+ t = t.view(new_view_t).expand(new_shape)
446
+
447
+ return p0 + t * (p1 - p0)
ldf_utils/motion_process.py ADDED
@@ -0,0 +1,365 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from ldf_utils.math.quaternion import *
6
+
7
+ """
8
+ Motion data structure:
9
+ (B: batch size)
10
+ root_rot_velocity (B, seq_len, 1)
11
+ root_linear_velocity (B, seq_len, 2)
12
+ root_y (B, seq_len, 1)
13
+ ric_data (B, seq_len, (joint_num - 1)*3)
14
+ rot_data (B, seq_len, (joint_num - 1)*6)
15
+ local_velocity (B, seq_len, joint_num*3)
16
+ foot contact (B, seq_len, 4)
17
+ """
18
+
19
+
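The component sizes listed in the docstring above add up to the 263-dimensional feature used for a 22-joint skeleton; a quick check of that arithmetic:

joints_num = 22
parts = {
    "root_rot_velocity": 1,
    "root_linear_velocity": 2,
    "root_y": 1,
    "ric_data": (joints_num - 1) * 3,    # 63
    "rot_data": (joints_num - 1) * 6,    # 126
    "local_velocity": joints_num * 3,    # 66
    "foot_contact": 4,
}
print(sum(parts.values()))               # 263
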
20
+ def recover_root_rot_pos(data):
21
+ # recover root rotation and position
22
+ rot_vel = data[..., 0]
23
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
24
+ """Get Y-axis rotation from rotation velocity"""
25
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
26
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
27
+
28
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
29
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
30
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
31
+
32
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
33
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
34
+ """Add Y-axis rotation to root position"""
35
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
36
+
37
+ r_pos = torch.cumsum(r_pos, dim=-2)
38
+
39
+ r_pos[..., 1] = data[..., 3]
40
+ return r_rot_quat, r_pos
41
+
42
+
43
+ def recover_joint_positions_263(data: np.ndarray, joints_num) -> np.ndarray:
44
+ """
45
+ Recovers 3D joint positions from the rotation-invariant local positions (ric_data).
46
+ This is the most direct way to get the skeleton for animation.
47
+ """
48
+ feature_vec = torch.from_numpy(data).unsqueeze(0).float()
49
+ r_rot_quat, r_pos = recover_root_rot_pos(feature_vec)
50
+ positions = feature_vec[..., 4 : (joints_num - 1) * 3 + 4]
51
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
52
+ """Add Y-axis rotation to local joints"""
53
+ positions = qrot(
54
+ qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions
55
+ )
56
+ """Add root XZ to joints"""
57
+ positions[..., 0] += r_pos[..., 0:1]
58
+ positions[..., 2] += r_pos[..., 2:3]
59
+ """Concatenate root and joints"""
60
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
61
+ joints_np = positions.squeeze(0).detach().cpu().numpy()
62
+ return joints_np
63
+
64
+
65
+ class StreamJointRecovery263:
66
+ """
67
+ Stream version of recover_joint_positions_263 that processes one frame at a time.
68
+ Maintains cumulative state for rotation angles and positions.
69
+
70
+ Key insight: The batch version uses PREVIOUS frame's velocity for the current frame,
71
+ so we need to delay the velocity application by one frame.
72
+
73
+ Args:
74
+ joints_num: Number of joints in the skeleton
75
+ smoothing_alpha: EMA smoothing factor (0.0 to 1.0)
76
+ - 1.0 = no smoothing (default), output follows input exactly
77
+ - 0.0 = infinite smoothing, output never changes
78
+ - Recommended values: 0.3-0.7 for visible smoothing
79
+ - Formula: smoothed = alpha * current + (1 - alpha) * previous
80
+ """
81
+
82
+ def __init__(self, joints_num: int, smoothing_alpha: float = 1.0):
83
+ self.joints_num = joints_num
84
+ self.smoothing_alpha = np.clip(smoothing_alpha, 0.0, 1.0)
85
+ self.reset()
86
+
87
+ def reset(self):
88
+ """Reset the accumulated state"""
89
+ self.r_rot_ang_accum = 0.0
90
+ self.r_pos_accum = np.array([0.0, 0.0, 0.0])
91
+ # Store previous frame's velocities for delayed application
92
+ self.prev_rot_vel = 0.0
93
+ self.prev_linear_vel = np.array([0.0, 0.0])
94
+ # Store previous smoothed joints for EMA
95
+ self.prev_smoothed_joints = None
96
+
97
+ def process_frame(self, frame_data: np.ndarray) -> np.ndarray:
98
+ """
99
+ Process a single frame and return joint positions for that frame.
100
+
101
+ Args:
102
+ frame_data: numpy array of shape (263,) for a single frame
103
+
104
+ Returns:
105
+ joints: numpy array of shape (joints_num, 3) representing joint positions
106
+ """
107
+ # Convert to torch tensor
108
+ feature_vec = torch.from_numpy(frame_data).float()
109
+
110
+ # Extract current frame's velocities (will be used in NEXT frame)
111
+ curr_rot_vel = feature_vec[0].item()
112
+ curr_linear_vel = feature_vec[1:3].numpy()
113
+
114
+ # Update accumulated rotation angle with PREVIOUS frame's velocity FIRST
115
+ # This matches the batch processing: r_rot_ang[i] uses rot_vel[i-1]
116
+ self.r_rot_ang_accum += self.prev_rot_vel
117
+
118
+ # Calculate current rotation quaternion using updated accumulated angle
119
+ r_rot_quat = torch.zeros(4)
120
+ r_rot_quat[0] = np.cos(self.r_rot_ang_accum)
121
+ r_rot_quat[2] = np.sin(self.r_rot_ang_accum)
122
+
123
+ # Create velocity vector with Y=0 using PREVIOUS frame's velocity
124
+ r_vel = np.array([self.prev_linear_vel[0], 0.0, self.prev_linear_vel[1]])
125
+
126
+ # Apply inverse rotation to velocity using CURRENT rotation
127
+ r_vel_torch = torch.from_numpy(r_vel).float()
128
+ r_vel_rotated = qrot(qinv(r_rot_quat).unsqueeze(0), r_vel_torch.unsqueeze(0))
129
+ r_vel_rotated = r_vel_rotated.squeeze(0).numpy()
130
+
131
+ # Update accumulated position with rotated velocity
132
+ self.r_pos_accum += r_vel_rotated
133
+
134
+ # Get Y position from data
135
+ r_pos = self.r_pos_accum.copy()
136
+ r_pos[1] = feature_vec[3].item()
137
+
138
+ # Extract local joint positions
139
+ positions = feature_vec[4 : (self.joints_num - 1) * 3 + 4]
140
+ positions = positions.view(-1, 3)
141
+
142
+ # Apply inverse rotation to local joints
143
+ r_rot_quat_expanded = (
144
+ qinv(r_rot_quat).unsqueeze(0).expand(positions.shape[0], 4)
145
+ )
146
+ positions = qrot(r_rot_quat_expanded, positions)
147
+
148
+ # Add root XZ to joints
149
+ positions[:, 0] += r_pos[0]
150
+ positions[:, 2] += r_pos[2]
151
+
152
+ # Concatenate root and joints
153
+ r_pos_torch = torch.from_numpy(r_pos).float()
154
+ positions = torch.cat([r_pos_torch.unsqueeze(0), positions], dim=0)
155
+
156
+ # Convert to numpy
157
+ joints_np = positions.detach().cpu().numpy()
158
+
159
+ # Apply EMA smoothing if enabled
160
+ if self.smoothing_alpha < 1.0:
161
+ if self.prev_smoothed_joints is None:
162
+ # First frame, no smoothing possible
163
+ self.prev_smoothed_joints = joints_np.copy()
164
+ else:
165
+ # EMA: smoothed = alpha * current + (1 - alpha) * previous
166
+ joints_np = (
167
+ self.smoothing_alpha * joints_np
168
+ + (1.0 - self.smoothing_alpha) * self.prev_smoothed_joints
169
+ )
170
+ self.prev_smoothed_joints = joints_np.copy()
171
+
172
+ # Store current velocities for next frame
173
+ self.prev_rot_vel = curr_rot_vel
174
+ self.prev_linear_vel = curr_linear_vel
175
+
176
+ return joints_np
177
+
178
+
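When smoothing is disabled (smoothing_alpha=1.0), the streaming recovery above is intended to reproduce the batch function frame for frame. A small comparison sketch on random features (shapes only, not a real motion):

import numpy as np
from ldf_utils.motion_process import StreamJointRecovery263, recover_joint_positions_263

features = np.random.randn(16, 263).astype(np.float32)                 # 16 frames

batch_joints = recover_joint_positions_263(features, joints_num=22)    # (16, 22, 3)

stream = StreamJointRecovery263(joints_num=22, smoothing_alpha=1.0)
stream_joints = np.stack([stream.process_frame(f) for f in features])  # (16, 22, 3)

print(np.abs(batch_joints - stream_joints).max())                      # expected to be ~0
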
179
+ def accumulate_rotations(relative_rotations):
180
+ R_total = [relative_rotations[0]]
181
+ for R_rel in relative_rotations[1:]:
182
+ R_total.append(np.matmul(R_rel, R_total[-1]))
183
+
184
+ return np.array(R_total)
185
+
186
+
187
+ def recover_from_local_position(final_x, njoint):
188
+ nfrm, _ = final_x.shape
189
+ positions_no_heading = final_x[:, 8 : 8 + 3 * njoint].reshape(
190
+ nfrm, -1, 3
191
+ ) # frames, njoints * 3
192
+ velocities_root_xy_no_heading = final_x[:, :2] # frames, 2
193
+ global_heading_diff_rot = final_x[:, 2:8] # frames, 6
194
+
195
+ # recover global heading
196
+ global_heading_rot = accumulate_rotations(
197
+ rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()
198
+ )
199
+ inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
200
+ # add global heading to position
201
+ positions_with_heading = np.matmul(
202
+ np.repeat(inv_global_heading_rot[:, None, :, :], njoint, axis=1),
203
+ positions_no_heading[..., None],
204
+ ).squeeze(-1)
205
+
206
+ # recover root translation
207
+ # add heading to velocities_root_xy_no_heading
208
+
209
+ velocities_root_xyz_no_heading = np.zeros(
210
+ (
211
+ velocities_root_xy_no_heading.shape[0],
212
+ 3,
213
+ )
214
+ )
215
+ velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
216
+ velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
217
+ velocities_root_xyz_no_heading[1:, :] = np.matmul(
218
+ inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]
219
+ ).squeeze(-1)
220
+
221
+ root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)
222
+
223
+ # add root translation
224
+ positions_with_heading[:, :, 0] += root_translation[:, 0:1]
225
+ positions_with_heading[:, :, 2] += root_translation[:, 2:]
226
+
227
+ return positions_with_heading
228
+
229
+
230
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
231
+ a1, a2 = d6[..., :3], d6[..., 3:]
232
+ b1 = F.normalize(a1, dim=-1)
233
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
234
+ b2 = F.normalize(b2, dim=-1)
235
+ b3 = torch.cross(b1, b2, dim=-1)
236
+ return torch.stack((b1, b2, b3), dim=-2)
237
+
238
+
239
+ def _copysign(a, b):
240
+ signs_differ = (a < 0) != (b < 0)
241
+ return torch.where(signs_differ, -a, a)
242
+
243
+
244
+ def _sqrt_positive_part(x):
245
+ ret = torch.zeros_like(x)
246
+ positive_mask = x > 0
247
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
248
+ return ret
249
+
250
+
251
+ def matrix_to_quaternion(matrix):
252
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
253
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
254
+ m00 = matrix[..., 0, 0]
255
+ m11 = matrix[..., 1, 1]
256
+ m22 = matrix[..., 2, 2]
257
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
258
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
259
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
260
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
261
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
262
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
263
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
264
+ return torch.stack((o0, o1, o2, o3), -1)
265
+
266
+
+ def quaternion_to_axis_angle(quaternions):
+     norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
+     half_angles = torch.atan2(norms, quaternions[..., :1])
+     angles = 2 * half_angles
+     eps = 1e-6
+     small_angles = angles.abs() < eps
+     sin_half_angles_over_angles = torch.empty_like(angles)
+     sin_half_angles_over_angles[~small_angles] = (
+         torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+     )
+     # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+     # so sin(x/2)/x is about 1/2 - (x*x)/48
+     sin_half_angles_over_angles[small_angles] = (
+         0.5 - (angles[small_angles] * angles[small_angles]) / 48
+     )
+     return quaternions[..., 1:] / sin_half_angles_over_angles
+
+
+ def matrix_to_axis_angle(matrix):
+     return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
+
+
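As a quick, illustrative sanity check of the rotation helpers above (not part of the shipped file; it assumes the functions are already defined in the current namespace), a 90-degree rotation about the y axis, encoded as the first two rows of its rotation matrix, should round-trip to roughly the axis-angle vector [0, pi/2, 0]:

```python
import torch

# 6D encoding = first two rows of the target rotation matrix R_y(90 deg).
d6 = torch.tensor([[0.0, 0.0, 1.0, 0.0, 1.0, 0.0]])
R = rotation_6d_to_matrix(d6)   # shape (1, 3, 3)
aa = matrix_to_axis_angle(R)    # shape (1, 3), approximately [0, 1.5708, 0]
print(aa)
```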
+ def rotations_matrix_to_smpl85(rotations_matrix, translation):
+     nfrm, njoint, _, _ = rotations_matrix.shape
+     axis_angle = (
+         matrix_to_axis_angle(torch.from_numpy(rotations_matrix))
+         .numpy()
+         .reshape(nfrm, -1)
+     )
+     # For 22 body joints: 66 axis-angle values + 6 zeros (padding to SMPL's
+     # 24 joints) + 3 translation values + 10 zero shape betas = 85 dims.
+     smpl_85 = np.concatenate(
+         [axis_angle, np.zeros((nfrm, 6)), translation, np.zeros((nfrm, 10))], axis=-1
+     )
+     return smpl_85
+
+
+ def recover_from_local_rotation(final_x, njoint):
+     nfrm, _ = final_x.shape
+     rotations_matrix = rotation_6d_to_matrix(
+         torch.from_numpy(final_x[:, 8 + 6 * njoint : 8 + 12 * njoint]).reshape(
+             nfrm, -1, 6
+         )
+     ).numpy()
+     global_heading_diff_rot = final_x[:, 2:8]
+     velocities_root_xy_no_heading = final_x[:, :2]
+     positions_no_heading = final_x[:, 8 : 8 + 3 * njoint].reshape(nfrm, -1, 3)
+     height = positions_no_heading[:, 0, 1]
+
+     global_heading_rot = accumulate_rotations(
+         rotation_6d_to_matrix(torch.from_numpy(global_heading_diff_rot)).numpy()
+     )
+     inv_global_heading_rot = np.transpose(global_heading_rot, (0, 2, 1))
+     # recover root rotation
+     rotations_matrix[:, 0, ...] = np.matmul(
+         inv_global_heading_rot, rotations_matrix[:, 0, ...]
+     )
+     velocities_root_xyz_no_heading = np.zeros(
+         (
+             velocities_root_xy_no_heading.shape[0],
+             3,
+         )
+     )
+     velocities_root_xyz_no_heading[:, 0] = velocities_root_xy_no_heading[:, 0]
+     velocities_root_xyz_no_heading[:, 2] = velocities_root_xy_no_heading[:, 1]
+     velocities_root_xyz_no_heading[1:, :] = np.matmul(
+         inv_global_heading_rot[:-1], velocities_root_xyz_no_heading[1:, :, None]
+     ).squeeze(-1)
+     root_translation = np.cumsum(velocities_root_xyz_no_heading, axis=0)
+     root_translation[:, 1] = height
+     smpl_85 = rotations_matrix_to_smpl85(rotations_matrix, root_translation)
+     return smpl_85
+
+
+ def recover_joint_positions_272(data: np.ndarray, joints_num) -> np.ndarray:
+     return recover_from_local_position(data, joints_num)
+
+
+ def convert_motion_to_joints(
+     motion_data: np.ndarray,
+     dim: int,
+     mean: np.ndarray = None,
+     std: np.ndarray = None,
+     joints_num=22,
+ ):
+     """
+     Convert (K, 263)- or (K, 272)-dimensional motion data to (K, 22, 3) joint positions.
+     Args:
+         motion_data: numpy array of shape (K, 263) or (K, 272), where K is the number of frames
+         dim: feature dimension of the representation, either 263 or 272
+         mean, std: optional normalization statistics; if both are given, the data is de-normalized first
+         joints_num: number of skeleton joints to recover (22 by default)
+     Returns:
+         joints: numpy array of shape (K, 22, 3) representing joint positions
+     """
+     if mean is not None and std is not None:
+         motion_data = motion_data * std + mean
+     if dim == 263:
+         recovered_positions = recover_joint_positions_263(motion_data, joints_num)
+     elif dim == 272:
+         recovered_positions = recover_joint_positions_272(motion_data, joints_num)
+     else:
+         raise ValueError(f"Unsupported motion data dimension: {dim}")
+     return recovered_positions
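For reference, a minimal usage sketch of `convert_motion_to_joints`; the file names and the source of the generated motion are illustrative placeholders, not files shipped in this repository:

```python
import numpy as np

# Denormalize a generated sequence and recover 3D joint positions
# from the 272-dim representation (all paths are placeholders).
motion = np.load("generated_motion.npy")   # (K, 272), normalized
mean = np.load("mean.npy")                 # (272,)
std = np.load("std.npy")                   # (272,)

joints = convert_motion_to_joints(motion, dim=272, mean=mean, std=std, joints_num=22)
print(joints.shape)                        # (K, 22, 3)
```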
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3528a345e2795f0b28343896515adc2c14746567896c66620852678ff8d43a79
+ size 36753080
requirements.txt ADDED
@@ -0,0 +1,19 @@
+ # Core dependencies
+ torch>=2.0.0
+ transformers>=4.30.0
+ huggingface_hub>=0.16.0
+ safetensors>=0.3.0
+ diffusers>=0.20.0
+
+ # Inference
+ lightning>=2.0.0
+ ftfy
+
+ # Configuration
+ omegaconf
+
+ # Utilities
+ numpy
+
+ # Note: flash-attn is required but needs special installation
+ # See README.md for installation instructions
vae.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a40164154c476309ff952a4b7563750b7e76fbdd8d263ec261ad877cf452e7b
+ size 70027220