NikhilMarisetty committed on
Commit eb71a72 · verified · 1 Parent(s): 8b783f1

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ teaser/teaser.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,187 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Project-specific
+ experiments
+ data/finedance
+ generated
+ eval
+ wandb
+ assets/checkpoints/
+ assets/smpl_model/
+ assets/ffmpeg-6.0-amd64-static/
+ assets/NORMAL_new.obj
+
+ # User-generated output and uploaded music
+ output/
+ custom_music/
+ *.mp4
+ *.wav
+ *.mp3
+ *.flac
+ *.ogg
+ *.m4a
+
+ # macOS
+ .DS_Store
+
+ # VSCode
+ .vscode/
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,52 @@
+ License
+ Software Copyright License for non-commercial scientific research purposes
+ Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the FineDance data, model and software (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License.
+
+ Ownership / Licensees
+ The Software and the associated materials have been developed by
+
+ Professor Xiu Li, Tsinghua University.
+
+ Any copyright or patent right is owned by and proprietary material of
+
+ Professor Xiu Li, Tsinghua University.
+
+ hereinafter the “Licensor”.
+
+ License Grant
+ Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
+
+ To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization;
+ To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
+ Any other use, in particular any use for commercial, pornographic, military, or surveillance purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Xiu Li’s prior written permission.
+
+ The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
+
+ No Distribution
+ The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
+
+ Disclaimer of Representations and Warranties
+ You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.
+
+ Limitation of Liability
+ The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.
+
+ No Maintenance Services
+ You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.
+
+ Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
+
+ Publications using the Data & Software
+ You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.
+
+ Citation:
+
+ @InProceedings{Li_2023_ICCV,
+     author    = {Li, Ronghui and Zhao, Junfan and Zhang, Yachao and Su, Mingyang and Ren, Zeping and Zhang, Han and Tang, Yansong and Li, Xiu},
+     title     = {FineDance: A Fine-grained Choreography Dataset for 3D Full Body Dance Generation},
+     booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+     month     = {October},
+     year      = {2023},
+     pages     = {10234-10243}
+ }
README.md CHANGED
@@ -1,3 +1,164 @@
- ---
- license: mit
- ---
+ # [FineDance: A Fine-grained Choreography Dataset for 3D Full Body Dance Generation (ICCV 2023)](https://github.com/li-ronghui/FineDance)
+
+ [[Project Page](https://li-ronghui.github.io/finedance)] | [[Preprint](https://arxiv.org/abs/2212.03741)] | [[pdf](https://arxiv.org/pdf/2212.03741.pdf)] | [[video](https://li-ronghui.github.io/finedance)]
+
+ <img src="teaser/teaser.png">
+
+ ## Quick Start
+
+ ### Prerequisites
+
+ Install the conda environment and activate it:
+
+ ```bash
+ conda env create -f environment.yaml
+ conda activate FineNet
+ ```
+
+ Download the pretrained checkpoints and asset files from [Google Drive](https://drive.google.com/file/d/1ENoeUn-X-3Vw2Gon-voVLlndy3hZXdWD/view?usp=drive_link).
+
+ ### Web UI (Recommended)
+
+ Launch the Gradio web interface:
+
+ ```bash
+ python app.py
+ ```
+
+ Open `http://127.0.0.1:7861` in your browser. Upload a music file and click "Generate Dance" to produce a video.
+
+ ### Command Line
+
+ ```bash
+ python generate_dance.py /path/to/music.mp3
+ ```
+
+ The output will be saved to `output/<songname>_dance.mp4`.
+
+ To specify a custom output path:
+
+ ```bash
+ python generate_dance.py /path/to/music.mp3 --output my_dance.mp4
+ ```
+
+ Supported audio formats: any format ffmpeg can read (`.mp3`, `.wav`, `.m4a`, `.flac`, `.ogg`, etc.).
+
+ ### Output Details
+
+ - Resolution: 1200x1200
+ - Frame rate: 30 fps
+ - Duration: ~30 seconds
+ - Background: black
+ - Body model: SMPLX (full body with hands)
+
+ ## How It Works
+
+ 1. **Audio conversion** - Converts the input to WAV if needed
+ 2. **Feature extraction** - Slices the audio into 4-second windows with a 2-second stride, then extracts 35-dim features per slice (onset envelope, 20 MFCC, 12 chroma, peak one-hot, beat one-hot) using librosa; see the sketch after this list
+ 3. **Dance generation** - Feeds the audio features into a pretrained diffusion model (`assets/checkpoints/train-2000.pt`), which generates SMPLX body motion (319-dim: 4 contact + 3 translation + 52 joints x 6D rotation)
+ 4. **Rendering** - Converts the generated motion to SMPLX meshes and renders 900 frames at 30 fps using pyrender
+ 5. **Final output** - Merges the rendered video with the original audio via ffmpeg
+
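+ As a reference for step 2, here is a minimal sketch of the slicing and feature assembly. The extraction calls mirror `data/code/pre_music.py` in this repository; the `extract_35d`/`slice_windows` helpers and the hard-coded 4-second window / 2-second stride are illustrative assumptions based on the description above, not the actual `generate_dance.py` API.
+
+ ```python
+ import librosa
+ import numpy as np
+
+ FPS, HOP = 30, 512
+ SR = FPS * HOP  # one librosa frame per video frame, as in data/code/pre_music.py
+
+ def extract_35d(y):
+     # envelope (1) + MFCC (20) + chroma (12) + peak one-hot (1) + beat one-hot (1) = 35 dims
+     env = librosa.onset.onset_strength(y=y, sr=SR)
+     mfcc = librosa.feature.mfcc(y=y, sr=SR, n_mfcc=20).T
+     chroma = librosa.feature.chroma_cens(y=y, sr=SR, hop_length=HOP, n_chroma=12).T
+     peak = np.zeros_like(env)
+     peak[librosa.onset.onset_detect(onset_envelope=env, sr=SR, hop_length=HOP)] = 1.0
+     beat = np.zeros_like(env)
+     beat[librosa.beat.beat_track(onset_envelope=env, sr=SR, hop_length=HOP)[1]] = 1.0
+     return np.concatenate([env[:, None], mfcc, chroma, peak[:, None], beat[:, None]], axis=-1)
+
+ def slice_windows(feats, win_s=4.0, stride_s=2.0):
+     # hypothetical slicer: 4-second windows, 2-second stride, 30 feature frames per second
+     win, stride = int(win_s * FPS), int(stride_s * FPS)
+     return [feats[i:i + win] for i in range(0, len(feats) - win + 1, stride)]
+
+ y, _ = librosa.load("song.wav", sr=SR)
+ slices = slice_windows(extract_35d(y))  # each slice has shape (120, 35)
+ ```
+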
+ ## FineDance Dataset
+
+ The dataset (7.7 hours) can be downloaded from [Google Drive](https://drive.google.com/file/d/1zQvWG9I0H4U3Zrm8d_QD_ehenZvqfQfS/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1gynUC7pMdpsE31wAwq177w?pwd=o9pw).
+
+ Put the downloaded data into `./data`. The data directory contains:
+
+ - **label_json** - Song name, coarse style, and fine-grained genre
+ - **motion** - [SMPLH](https://smpl-x.is.tue.mpg.de/) format motion data
+ - **music_wav** - Music data in WAV format
+ - **music_npy** - Music features extracted by [librosa](https://github.com/librosa/librosa) following [AIST++](https://github.com/google/aistplusplus_api/tree/main)
+
+ Reading a motion file:
+
+ ```python
+ import numpy as np
+ data = np.load("motion/001.npy")
+ T, C = data.shape  # T is the number of frames
+ smpl_poses = data[:, 3:]
+ smpl_trans = data[:, :3]
+ ```
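+
+ Note that the raw `motion/*.npy` files store root translation in the first 3 dims followed by poses, while the preprocessed 319-dim training features written by `data/code/pre_motion.py` use a different layout. Below is a sketch of unpacking them; the slice boundaries are taken from `pre_motion.py` and `dataset/FineDance_dataset.py`, and the file path is illustrative.
+
+ ```python
+ import numpy as np
+
+ fea = np.load("data/finedance/motion_fea319/001.npy")  # (T, 319)
+ contacts = fea[:, :4]                  # binary foot-contact flags for 4 foot joints
+ trans = fea[:, 4:7]                    # root translation
+ rot6d = fea[:, 7:].reshape(-1, 52, 6)  # 52 joint rotations in the 6D representation
+ ```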
+
+ ### Dataset Split
+
+ The dataset is split into train, val, and test sets in two ways:
+
+ 1. **FineDance@Genre** - The test set covers a broader range of dance genres; the same dancer may appear across splits, but with different motions. Recommended for dance generation.
+ 2. **FineDance@Dancer** - The splits are divided by dancer, so the same dancer never appears in different sets, but the test set contains fewer genres.
+
+ ## Training
+
+ Training is only needed if you want to start from scratch; the pretrained checkpoint is already provided.
+
+ ```bash
+ # Data preprocessing
+ python data/code/pre_motion.py
+
+ # Train
+ accelerate launch train_seq.py --batch_size 32 --epochs 200
+ ```
+
+ Key flags (see the example after this list):
+ - `--batch_size` - Default is 400; reduce to 32 or lower for Mac MPS (limited to ~30 GB)
+ - `--epochs` - Default is 2000
+ - `--checkpoint` - Resume from a saved checkpoint
+
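+ For example, to fine-tune from the provided checkpoint (a sketch; these flags are defined in `args.py`, but whether optimizer state is restored depends on `train_seq.py`):
+
+ ```bash
+ accelerate launch train_seq.py --batch_size 32 --epochs 2000 \
+     --checkpoint assets/checkpoints/train-2000.pt
+ ```
+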
+ ## Advanced Usage
+
+ ### Generate on the test set
+
+ ```bash
+ python data/code/slice_music_motion.py
+ python generate_all.py --motion_save_dir generated/finedance_seq_120_dancer --save_motions
+ ```
+
+ ### Render a pre-generated motion file
+
+ ```bash
+ python render.py --modir eval/motions --mode smplx
+ ```
+
+ ## Project Structure
+
+ ```
+ FineDance/
+ ├── app.py                    # Gradio web UI
+ ├── generate_dance.py         # One-command dance generation (CLI)
+ ├── train_seq.py              # Training script
+ ├── test.py                   # Original test/inference script
+ ├── render.py                 # Video rendering (SMPLX mesh to MP4)
+ ├── args.py                   # CLI argument definitions
+ ├── vis.py                    # Skeleton/FK utilities
+ ├── assets/
+ │   ├── checkpoints/
+ │   │   └── train-2000.pt     # Pretrained model (2000 epochs)
+ │   └── smpl_model/
+ │       └── smplx/
+ │           └── SMPLX_NEUTRAL.npz  # SMPLX body model
+ ├── model/
+ │   ├── model.py              # SeqModel (transformer decoder)
+ │   └── diffusion.py          # Gaussian diffusion (training + sampling)
+ ├── dataset/
+ │   └── FineDance_dataset.py  # Dataset loader
+ └── data/
+     └── finedance/            # Training data (music + motion pairs)
+ ```
+
+ ## Acknowledgments
+
+ We would like to express our sincere gratitude to Dr. [Yan Zhang](https://yz-cnsdqz.github.io/) and [Yulun Zhang](https://yulunzhang.com/) for their invaluable guidance and insights during the course of our research.
+
+ This code is based on: [EDGE](https://github.com/Stanford-TML/EDGE/tree/main), [MDM](https://github.com/GuyTevet/motion-diffusion-model), [Adan](https://github.com/lucidrains/Adan-pytorch), [Diffusion](https://github.com/lucidrains/denoising-diffusion-pytorch), [SMPLX](https://smpl-x.is.tue.mpg.de/).
+
+ ## Citation
+
+ ```
+ @inproceedings{li2023finedance,
+   title={FineDance: A Fine-grained Choreography Dataset for 3D Full Body Dance Generation},
+   author={Li, Ronghui and Zhao, Junfan and Zhang, Yachao and Su, Mingyang and Ren, Zeping and Zhang, Han and Tang, Yansong and Li, Xiu},
+   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+   pages={10234--10243},
+   year={2023}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,84 @@
+ """
+ Gradio UI for FineDance — generate dance videos from music.
+
+ Usage:
+     conda activate FineNet
+     python app.py
+ """
+
+ import os
+ import tempfile
+
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+ import gradio as gr
+
+ # Monkey-patch gradio_client bug: additionalProperties can be a bool,
+ # but the code assumes it's always a dict.
+ import gradio_client.utils as _gc_utils
+
+ _orig_json_schema_to_python_type = _gc_utils._json_schema_to_python_type
+
+
+ def _patched_json_schema_to_python_type(schema, defs=None):
+     if isinstance(schema, bool):
+         return "Any"
+     return _orig_json_schema_to_python_type(schema, defs)
+
+
+ _gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
+
+ from generate_dance import load_model, generate, _setup_render_args
+ from render import MovieMaker
+
+ # Preload model once at startup
+ print("Loading model...")
+ MODEL = load_model()
+ print("Model loaded. Initializing renderer...")
+
+ # Create MovieMaker on the main thread so pyglet's signal handler works.
+ _setup_render_args()
+ VISUALIZER = MovieMaker(save_path=".")
+ print("Starting UI...")
+
+
+ def run(audio_path):
+     if audio_path is None:
+         raise gr.Error("Please upload a music file.")
+
+     logs = []
+
+     def log_fn(msg):
+         logs.append(msg)
+         print(msg)
+
+     songname = os.path.splitext(os.path.basename(audio_path))[0]
+     output_path = os.path.join(tempfile.gettempdir(), f"{songname}_dance.mp4")
+
+     generate(audio_path, output_path, model=MODEL, visualizer=VISUALIZER, log_fn=log_fn)
+
+     return output_path, "\n".join(logs)
+
+
+ with gr.Blocks(title="FineDance") as demo:
+     gr.Markdown("# FineDance\nUpload a music file to generate a 3D dance video.")
+
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(
+                 label="Upload Music",
+                 type="filepath",
+             )
+             generate_btn = gr.Button("Generate Dance", variant="primary")
+
+         with gr.Column():
+             video_output = gr.Video(label="Generated Dance")
+             status_output = gr.Textbox(label="Status", lines=6, interactive=False)
+
+     generate_btn.click(
+         fn=run,
+         inputs=[audio_input],
+         outputs=[video_output, status_output],
+     )
+
+ if __name__ == "__main__":
+     demo.launch(server_name="127.0.0.1")
args.py ADDED
@@ -0,0 +1,119 @@
+ import argparse
+ import yaml
+
+
+ def FineDance_parse_train_opt():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--project", default="experiments/finedance_seq_120_genre/train", help="project/name")
+     parser.add_argument("--exp_name", default="finedance_seq_120_genre", help="save to project/name")
+     parser.add_argument("--feature_type", type=str, default="baseline")
+     parser.add_argument("--datasplit", type=str, default="cross_genre", choices=["cross_genre", "cross_dancer"])
+     parser.add_argument(
+         "--render_dir", type=str, default="experiments/finedance_seq_120_genre/renders", help="Sample render path"
+     )
+     parser.add_argument(
+         "--full_seq_len", type=int, default=120, help="full_seq_len"
+     )
+     parser.add_argument(
+         "--windows", type=int, default=10, help="windows"
+     )
+     parser.add_argument(
+         "--mix", action="store_true", help="Saves the motions for evaluation"
+     )
+     # parser.add_argument("--feature_type", type=str, default="jukebox")
+     parser.add_argument(
+         "--wandb_pj_name", type=str, default="finedance_seq", help="project name"
+     )
+     parser.add_argument("--batch_size", type=int, default=400, help="batch size")  # default=64
+     parser.add_argument("--epochs", type=int, default=2000)
+     parser.add_argument(
+         "--save_interval",
+         type=int,
+         default=10,  # default=100,
+         help='Log model after every "save_period" epoch',
+     )
+     parser.add_argument("--ema_interval", type=int, default=1, help="ema every x steps")
+     parser.add_argument(
+         "--checkpoint", type=str, default="", help="trained checkpoint path (optional)"
+     )
+     parser.add_argument(
+         "--do_normalize",
+         action="store_true",
+         help="normalize",
+     )
+     parser.add_argument(
+         "--nfeats", type=int, default=319, help="nfeats"
+     )
+     opt = parser.parse_args()
+     return opt
+
+
+ def FineDance_parse_test_opt():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--feature_type", type=str, default="baseline")
+     parser.add_argument(
+         "--full_seq_len", type=int, default=120, help="full_seq_len"
+     )
+     parser.add_argument("--datasplit", type=str, default="cross_genre", choices=["cross_genre", "cross_dancer"])
+     parser.add_argument(
+         "--windows", type=int, default=10, help="windows"
+     )
+     parser.add_argument("--out_length", type=float, default=30, help="max. length of output, in seconds")
+     parser.add_argument(
+         "--render_dir", type=str, default="FineDance_test_renders/", help="Sample render path"
+     )
+     parser.add_argument(
+         "--checkpoint", type=str, default="assets/checkpoints/train-2000.pt", help="checkpoint"
+     )
+     parser.add_argument(
+         "--nfeats", type=int, default=319, help="nfeats"
+     )
+     parser.add_argument(
+         "--music_dir",
+         type=str,
+         default="data/finedance/music_wav",
+         help="folder containing input music",
+     )
+     parser.add_argument(
+         "--save_motions", action="store_true", help="Saves the motions for evaluation"
+     )
+     parser.add_argument(
+         "--motion_save_dir",
+         type=str,
+         default="eval/motions",
+         help="Where to save the motions",
+     )
+     parser.add_argument(
+         "--cache_features",
+         action="store_true",
+         help="Save the jukebox features for later reuse",
+     )
+     parser.add_argument(
+         "--do_normalize",
+         action="store_true",
+         help="normalize",
+     )
+     parser.add_argument(
+         "--no_render",
+         action="store_true",
+         help="Don't render the video",
+     )
+     parser.add_argument(
+         "--use_cached_features",
+         action="store_true",
+         help="Use precomputed features instead of music folder",
+     )
+     parser.add_argument(
+         "--feature_cache_dir",
+         type=str,
+         default="cached_features/",
+         help="Where to save/load the features",
+     )
+     opt = parser.parse_args()
+     return opt
+
+
+ def save_arguments_to_yaml(args, file_path):
+     arg_dict = vars(args)  # convert the Namespace object to a dict
+     yaml_str = yaml.dump(arg_dict, default_flow_style=False)
+
+     with open(file_path, 'w') as file:
+         file.write(yaml_str)
data/code/pre_motion.py ADDED
@@ -0,0 +1,72 @@
+ import argparse
+ import os
+ from pathlib import Path
+ import smplx, pickle
+ import torch
+ import sys
+ from tqdm import tqdm
+ import glob
+ import numpy as np
+
+ sys.path.append(os.getcwd())
+ from dataset.quaternion import ax_to_6v, ax_from_6v
+ from dataset.preprocess import Normalizer, vectorize_many
+
+
+ def motion_feats_extract(inputs_dir, outputs_dir):
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+     print("extracting")
+     raw_fps = 30
+     data_fps = 30
+     assert data_fps <= raw_fps
+     if not os.path.exists(outputs_dir):
+         os.makedirs(outputs_dir)
+     # All motion is retargeted to this standard model.
+     smplx_model = smplx.SMPLX(model_path='assets/smpl_model/smplx', ext='npz', gender='neutral',
+                               num_betas=10, flat_hand_mean=True, num_expression_coeffs=10, use_pca=False).eval().to(device)
+
+     motions = sorted(glob.glob(os.path.join(inputs_dir, "*.npy")))
+     for motion in tqdm(motions):
+         name = os.path.splitext(os.path.basename(motion))[0].split(".")[0]
+         print("name is", name)
+         data = np.load(motion, allow_pickle=True)
+         print(data.shape)
+         pos = data[:, :3]  # T, 3
+         q = data[:, 3:]
+         root_pos = torch.Tensor(pos).to(device)  # T, 3
+         length = root_pos.shape[0]
+         local_q_rot6d = torch.Tensor(q).to(device)  # T, 312
+         print("local_q_rot6d", local_q_rot6d.shape)
+         local_q = local_q_rot6d.reshape(length, 52, 6).clone()
+         local_q = ax_from_6v(local_q).view(length, 156)  # T, 156
+
+         smplx_output = smplx_model(
+             betas=torch.zeros([root_pos.shape[0], 10], device=device, dtype=torch.float32),
+             transl=root_pos,  # global translation
+             global_orient=local_q[:, :3],
+             body_pose=local_q[:, 3:66],  # 21 joints
+             jaw_pose=torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32),   # 1 joint
+             leye_pose=torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32),  # 1 joint
+             reye_pose=torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32),  # 1 joint
+             left_hand_pose=local_q[:, 66:66 + 45],  # 15 joints
+             right_hand_pose=local_q[:, 66 + 45:],   # 15 joints
+             expression=torch.zeros([root_pos.shape[0], 10], device=device, dtype=torch.float32),
+             return_verts=False
+         )
+
+         positions = smplx_output.joints.view(length, -1, 3)  # T, j, 3
+         feet = positions[:, (7, 8, 10, 11)]  # T, 4, 3
+         feetv = torch.zeros(feet.shape[:2], device=device)  # T, 4
+         feetv[:-1] = (feet[1:] - feet[:-1]).norm(dim=-1)
+         contacts = (feetv < 0.01).to(local_q)  # cast to the right dtype; T, 4
+
+         mofea319 = torch.cat([contacts, root_pos, local_q_rot6d], dim=1)
+         assert mofea319.shape[1] == 319
+         mofea319 = mofea319.detach().cpu().numpy()
+         np.save(os.path.join(outputs_dir, name + '.npy'), mofea319)
+     return
+
+
+ if __name__ == "__main__":
+     motion_feats_extract("data/finedance/motion", "data/finedance/motion_fea319")
data/code/pre_music.py ADDED
@@ -0,0 +1,90 @@
+ import librosa
+ import numpy as np
+ import os
+ import wave
+ from tqdm import tqdm
+ import librosa as lr
+
+ FPS = 30  # * 5
+ HOP_LENGTH = 512
+ SR = FPS * HOP_LENGTH
+ EPS = 1e-6
+
+ # HOP_LENGTH = 160
+ # SR = 16000
+
+ audio_dir = 'data/finedance/music_wav'
+ # audio_dir = '/home/human/datasets/aist_plusplus_final/music'
+ # audio_dir = "/home/human/datasets/data/Clip/music_clip_rhythm"
+
+ target_dir_ori = "data/finedance/music_wav_test"
+ os.makedirs(target_dir_ori, exist_ok=True)
+
+
+ # AIST++
+ def _get_tempo(audio_name):
+     """Get the tempo (BPM) of a music clip by parsing its name."""
+     # the name has many fields; only the 5th element encodes the tempo
+     audio_name = audio_name.split("_")[4]
+     assert len(audio_name) == 4
+     if audio_name[0:3] in [
+         "mBR",
+         "mPO",
+         "mLO",
+         "mMH",
+         "mLH",
+         "mWA",
+         "mKR",
+         "mJS",
+         "mJB",
+     ]:
+         return int(audio_name[3]) * 10 + 80
+     elif audio_name[0:3] == "mHO":
+         return int(audio_name[3]) * 5 + 110
+     else:
+         assert False, audio_name
+
+
+ for file in tqdm(os.listdir(audio_dir)):
+     audio_name = file[:-4]
+
+     save_path = os.path.join(target_dir_ori, f"{audio_name}.npy")  # path to save the extracted features
+     music_file = os.path.join(audio_dir, file)
+
+     data, _ = librosa.load(music_file, sr=SR)
+
+     envelope = librosa.onset.onset_strength(y=data, sr=SR)  # (seq_len,)
+     mfcc = librosa.feature.mfcc(y=data, sr=SR, n_mfcc=20).T  # (seq_len, 20)
+     chroma = librosa.feature.chroma_cens(
+         y=data, sr=SR, hop_length=HOP_LENGTH, n_chroma=12
+     ).T  # (seq_len, 12)
+
+     peak_idxs = librosa.onset.onset_detect(
+         onset_envelope=envelope.flatten(), sr=SR, hop_length=HOP_LENGTH
+     )
+     peak_onehot = np.zeros_like(envelope, dtype=np.float32)
+     peak_onehot[peak_idxs] = 1.0  # (seq_len,)
+
+     try:
+         start_bpm = _get_tempo(audio_name)
+     except Exception:
+         # the name does not follow the AIST++ scheme, so estimate the tempo instead
+         start_bpm = lr.beat.tempo(y=lr.load(music_file)[0])[0]
+
+     tempo, beat_idxs = librosa.beat.beat_track(
+         onset_envelope=envelope,
+         sr=SR,
+         hop_length=HOP_LENGTH,
+         start_bpm=start_bpm,
+         tightness=100,
+     )
+     beat_onehot = np.zeros_like(envelope, dtype=np.float32)
+     beat_onehot[beat_idxs] = 1.0  # (seq_len,)
+
+     audio_feature = np.concatenate(
+         [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
+         axis=-1,
+     )
+     np.save(save_path, audio_feature)
data/code/slice_music_motion.py ADDED
@@ -0,0 +1,41 @@
+ import numpy as np
+ import os
+ import sys
+
+ music_dir = "data/finedance/music_npy"
+ motion_dir = "data/finedance/motion_fea319"
+
+ music_out = "data/finedance/div_by_time/music_npy_"
+ motion_out = "data/finedance/div_by_time/motion_fea319_"
+
+ timelen = 120
+
+
+ music_out = music_out + str(timelen)
+ motion_out = motion_out + str(timelen)
+ if not os.path.exists(music_out):
+     os.makedirs(music_out)
+ if not os.path.exists(motion_out):
+     os.makedirs(motion_out)
+
+
+ for file in os.listdir(motion_dir):
+     if file[-3:] != 'npy':
+         print(file[-3:])
+         continue
+     name = file.split(".")[0]
+     music_fea = np.load(os.path.join(music_dir, file))
+     motion_fea = np.load(os.path.join(motion_dir, file))
+     max_length = min(music_fea.shape[0], motion_fea.shape[0])
+
+     iters = (max_length // timelen)
+     max_length = iters * timelen
+     music_fea = music_fea[:max_length, :]
+     motion_fea = motion_fea[:max_length, :]
+
+     for i in range(iters):
+         music_clip = music_fea[i * timelen: (i + 1) * timelen, :]
+         motion_clip = motion_fea[i * timelen: (i + 1) * timelen, :]
+         np.save(os.path.join(music_out, name + "z@" + str(i).zfill(3) + ".npy"), music_clip)
+         np.save(os.path.join(motion_out, name + "z@" + str(i).zfill(3) + ".npy"), motion_clip)
dataset/FineDance_dataset.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ from torch.utils import data
+ import numpy as np
+ import os
+ from tqdm import tqdm
+ import json
+
+ import sys
+ sys.path.insert(0, '.')
+
+ SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22]
+
+ SMPLX_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13,
+                           15, 17, 16, 19, 18, 21, 20, 22, 24, 23,
+                           40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+                           25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
+ SMPLX_POSE_FLIP_PERM = []
+ for i in SMPLX_JOINTS_FLIP_PERM:
+     SMPLX_POSE_FLIP_PERM.append(3 * i)
+     SMPLX_POSE_FLIP_PERM.append(3 * i + 1)
+     SMPLX_POSE_FLIP_PERM.append(3 * i + 2)
+
+
+ def flip_pose(pose):
+     # Flip pose. The flipping is based on SMPLX parameters.
+     pose = pose[:, SMPLX_POSE_FLIP_PERM]
+     # we also negate the second and the third dimension of the axis-angle
+     pose[:, 1::3] = -pose[:, 1::3]
+     pose[:, 2::3] = -pose[:, 2::3]
+     return pose
+
+
+ def get_train_test_list(datasplit):
+     all_list = []
+     train_list = []
+     for i in range(1, 212):
+         all_list.append(str(i).zfill(3))
+
+     if datasplit == "cross_genre":
+         test_list = ["063", "132", "143", "036", "098", "198", "130", "012", "211", "193", "179", "065", "137", "161", "092", "120", "037", "109", "204", "144"]
+         ignor_list = ["116", "117", "118", "119", "120", "121", "122", "123", "202"] + ["130"]
+     elif datasplit == "cross_dancer":
+         test_list = ['001', '002', '003', '004', '005', '006', '007', '008', '009', '010', '011', '012', '013', '124', '126', '128', '130', '132']
+         ignor_list = ['115', '117', '119', '121', '122', '135', '137', '139', '141', '143', '145', '147'] + ["116", "118", "120", "123", "202", "159"] + ["130"]  # the first list is the val set, the second is the ignore set
+     else:
+         raise ValueError("error of data split!")
+     for one in all_list:
+         if one not in test_list:
+             if one not in ignor_list:
+                 train_list.append(one)
+
+     return train_list, test_list, ignor_list
+
+
+ class FineDance_Smpl(data.Dataset):
+     def __init__(self, args, istrain):
+         self.motion_dir = './data/finedance/motion_fea319'
+         self.music_dir = './data/finedance/music_npy'
+         self.istrain = istrain
+         self.seq_len = args.full_seq_len
+         slide = args.full_seq_len // args.windows
+
+         self.motion_index = []
+         self.music_index = []
+         self.name = []
+         motion_all = []
+         music_all = []
+
+         train_list, test_list, ignor_list = get_train_test_list(args.datasplit)
+         if self.istrain:
+             self.datalist = train_list
+         else:
+             self.datalist = test_list
+
+         total_length = 0  # index all motions in the dataset with a single global index
+
+         for name in tqdm(self.datalist):
+             save_name = name
+             name = name + ".npy"
+
+             if name[:-4] in ignor_list:
+                 continue
+
+             motion = np.load(os.path.join(self.motion_dir, name))
+             music = np.load(os.path.join(self.music_dir, name))
+
+             min_all_len = min(motion.shape[0], music.shape[0])
+             motion = motion[:min_all_len]
+             if motion.shape[-1] == 168:
+                 motion = np.concatenate([motion[:, :69], motion[:, 78:]], axis=1)  # 22, 25
+             elif motion.shape[-1] == 319:
+                 pass
+             elif motion.shape[-1] == 315:
+                 pass
+                 # motion = np.concatenate([motion[:, :135], motion[:, 153:]], axis=1)
+             else:
+                 raise ValueError("input motion shape error! not 168, 315, or 319!")
+             music = music[:min_all_len]  # motion = motion[:min_all_len]
+             nums = (min_all_len - self.seq_len) // slide + 1  # the trailing segment shorter than seq_len is discarded
+
+             if self.istrain:
+                 clip_index = []
+                 for i in range(nums):
+                     motion_clip = motion[i * slide: i * slide + self.seq_len]
+                     if motion_clip.std(axis=0).mean() > 0.07:  # check whether the clip is a valid motion; if this is too slow, consider removing it
+                         clip_index.append(i)
+                 index = np.array(clip_index) * slide + total_length  # clip_index is a local index
+                 index_ = np.array(clip_index) * slide
+             else:
+                 index = np.arange(nums) * slide + total_length
+                 index_ = np.arange(nums) * slide
+
+             motion_all.append(motion)
+             music_all.append(music)
+
+             if args.mix:
+                 motion_index = []
+                 music_index = []
+                 num = (len(index) - 1) // 8 + 1
+                 for i in range(num):
+                     motion_index_tmp, music_index_tmp = np.meshgrid(index[i * 8:(i + 1) * 8], index[i * 8:(i + 1) * 8])  # is i a problem here? apparently not
+                     motion_index += motion_index_tmp.reshape((-1)).tolist()
+                     music_index += music_index_tmp.reshape((-1)).tolist()
+                     index_tmp = np.meshgrid(index_[i * 8:(i + 1) * 8])
+                     index_ += index_tmp.reshape((-1)).tolist()
+             else:
+                 motion_index = index.tolist()
+                 music_index = index.tolist()
+                 index_ = index_.tolist()
+             index_ = [save_name + "_" + str(element).zfill(5) for element in index_]
+
+             self.motion_index += motion_index
+             self.music_index += music_index
+             total_length += min_all_len
+             self.name += index_
+
+         self.motion = np.concatenate(motion_all, axis=0).astype(np.float32)
+         self.music = np.concatenate(music_all, axis=0).astype(np.float32)
+
+         self.len = len(self.motion_index)
+         print(f'FineDance has {self.len} samples..')
+
+     def __len__(self):
+         return self.len
+
+     def __getitem__(self, index):
+         motion_index = self.motion_index[index]
+         music_index = self.music_index[index]
+         motion = self.motion[motion_index:motion_index + self.seq_len]
+         if motion.shape[-1] == 319 or motion.shape[-1] == 139:
+             motion[:, 4:7] = motion[:, 4:7] - motion[:1, 4:7]  # the first 4 dimensions are foot contacts
+         else:
+             motion[:, :3] = motion[:, :3] - motion[:1, :3]
+         music = self.music[music_index:music_index + self.seq_len]
+         filename = self.name[index]
+         # if np.random.rand(1) > 0.5:
+         #     motion = motion[:, self.mirror_idx]
+         return motion, music, filename
+
+
+ if __name__ == '__main__':
+     data_split = {}
+     all_list = []
+     train_list = []
+     for i in range(1, 212):
+         all_list.append(str(i).zfill(3))
+     test_list = ["001", "002", "003", "004", "005", "006", "007", "008", "009", "010", "011", "012", "013", "124", "126", "128", "130", "132"]
+     val_list = ["115", "117", "119", "121", "122", "135", "137", "139", "141", "143", "145", "147"]
+     for one in all_list:
+         if one not in test_list:
+             if one not in val_list:
+                 train_list.append(one)
+
+     data_split["train"] = train_list
+     data_split["test"] = test_list
+     data_split["val"] = val_list
+     data_split["ignore"] = ["116", "117", "118", "119", "120", "121", "122", "123", "202"]
+
+     with open("data_crossdancer.json", "w") as f:
+         json.dump(data_split, f)
+
+     print(train_list)
dataset/__init__.py ADDED
File without changes
dataset/preprocess.py ADDED
@@ -0,0 +1,93 @@
+ import glob
+ import os
+ import re
+ from pathlib import Path
+
+ import torch
+
+ from .scaler import MinMaxScaler
+ import pickle
+
+
+ def increment_path(path, exist_ok=False, sep="", mkdir=False):
+     # Increment a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+     path = Path(path)  # os-agnostic
+     if path.exists() and not exist_ok:
+         suffix = path.suffix
+         path = path.with_suffix("")
+         dirs = glob.glob(f"{path}{sep}*")  # similar paths
+         matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
+         i = [int(m.groups()[0]) for m in matches if m]  # indices
+         n = max(i) + 1 if i else 2  # increment number
+         path = Path(f"{path}{sep}{n}{suffix}")  # update path
+     dir = path if path.suffix == "" else path.parent  # directory
+     if not dir.exists() and mkdir:
+         dir.mkdir(parents=True, exist_ok=True)  # make directory
+     return path
+
+
+ class Normalizer:
+     def __init__(self, data):
+         flat = data.reshape(-1, data.shape[-1])  # bxt, 151
+         self.scaler = MinMaxScaler((-1, 1), clip=True)
+         self.scaler.fit(flat)
+
+     def normalize(self, x):
+         batch, seq, ch = x.shape
+         x = x.reshape(-1, ch)
+         return self.scaler.transform(x).reshape((batch, seq, ch))
+
+     def unnormalize(self, x):
+         batch, seq, ch = x.shape
+         x = x.reshape(-1, ch)
+         x = torch.clip(x, -1, 1)  # clip to force compatibility
+         return self.scaler.inverse_transform(x).reshape((batch, seq, ch))
+
+
+ class My_Normalizer:
+     def __init__(self, data):
+         if isinstance(data, str):
+             self.scaler = MinMaxScaler((-1, 1), clip=True)
+             with open(data, 'rb') as f:
+                 normalizer_state_dict = pickle.load(f)
+             # normalizer_state_dict = torch.load(data)
+             self.scaler.scale_ = normalizer_state_dict["scale"]
+             self.scaler.min_ = normalizer_state_dict["min"]
+         else:
+             flat = data.reshape(-1, data.shape[-1])  # bxt, 151
+             self.scaler = MinMaxScaler((-1, 1), clip=True)
+             self.scaler.fit(flat)
+
+     def normalize(self, x):
+         if len(x.shape) == 3:
+             batch, seq, ch = x.shape
+             x = x.reshape(-1, ch)
+             return self.scaler.transform(x).reshape((batch, seq, ch))
+         elif len(x.shape) == 2:
+             batch, ch = x.shape
+             return self.scaler.transform(x)
+         else:
+             raise ValueError("input error!")
+
+     def unnormalize(self, x):
+         if len(x.shape) == 3:
+             batch, seq, ch = x.shape
+             x = x.reshape(-1, ch)
+             x = torch.clip(x, -1, 1)  # clip to force compatibility
+             return self.scaler.inverse_transform(x).reshape((batch, seq, ch))
+         elif len(x.shape) == 2:
+             x = torch.clip(x, -1, 1)
+             return self.scaler.inverse_transform(x)
+         else:
+             raise ValueError("input error!")
+
+
+ def vectorize_many(data):
+     # given a list of batch x seqlen x joints? x channels, flatten all to batch x seqlen x -1, concatenate
+     batch_size = data[0].shape[0]
+     seq_len = data[0].shape[1]
+
+     out = [x.reshape(batch_size, seq_len, -1).contiguous() for x in data]
+
+     global_pose_vec_gt = torch.cat(out, dim=2)
+     return global_pose_vec_gt
dataset/quaternion.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+ from pytorch3d.transforms import (axis_angle_to_matrix, matrix_to_axis_angle,
+                                   matrix_to_quaternion, matrix_to_rotation_6d,
+                                   quaternion_to_matrix, rotation_6d_to_matrix)
+
+
+ def quat_to_6v(q):
+     assert q.shape[-1] == 4
+     mat = quaternion_to_matrix(q)
+     mat = matrix_to_rotation_6d(mat)
+     return mat
+
+
+ def quat_from_6v(q):
+     assert q.shape[-1] == 6
+     mat = rotation_6d_to_matrix(q)
+     quat = matrix_to_quaternion(mat)
+     return quat
+
+
+ def ax_to_6v(q):
+     assert q.shape[-1] == 3
+     mat = axis_angle_to_matrix(q)
+     mat = matrix_to_rotation_6d(mat)
+     return mat
+
+
+ def ax_from_6v(q):
+     assert q.shape[-1] == 6
+     mat = rotation_6d_to_matrix(q)
+     ax = matrix_to_axis_angle(mat)
+     return ax
+
+
+ def quat_slerp(x, y, a):
+     """
+     Performs spherical linear interpolation (SLERP) between x and y, with proportion a.
+
+     :param x: quaternion tensor (N, S, J, 4)
+     :param y: quaternion tensor (N, S, J, 4)
+     :param a: interpolation weight (S, )
+     :return: tensor of interpolation results
+     """
+     len = torch.sum(x * y, axis=-1)
+
+     neg = len < 0.0
+     len[neg] = -len[neg]
+     y[neg] = -y[neg]
+
+     a = torch.zeros_like(x[..., 0]) + a
+
+     amount0 = torch.zeros_like(a)
+     amount1 = torch.zeros_like(a)
+
+     linear = (1.0 - len) < 0.01
+     omegas = torch.arccos(len[~linear])
+     sinoms = torch.sin(omegas)
+
+     amount0[linear] = 1.0 - a[linear]
+     amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms
+
+     amount1[linear] = a[linear]
+     amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms
+
+     # reshape
+     amount0 = amount0[..., None]
+     amount1 = amount1[..., None]
+
+     res = amount0 * x + amount1 * y
+
+     return res
dataset/scaler.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+
+
+ def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
+     # if we are fitting on 1D arrays, scale might be a scalar
+     if constant_mask is None:
+         # Detect near constant values to avoid dividing by a very small
+         # value that could lead to surprising results and numerical
+         # stability issues.
+         constant_mask = scale < 10 * torch.finfo(scale.dtype).eps
+
+     if copy:
+         # New array to avoid side-effects
+         scale = scale.clone()
+     scale[constant_mask] = 1.0
+     return scale
+
+
+ class MinMaxScaler:
+     _parameter_constraints: dict = {
+         "feature_range": [tuple],
+         "copy": ["boolean"],
+         "clip": ["boolean"],
+     }
+
+     def __init__(self, feature_range=(0, 1), *, copy=True, clip=False):
+         self.feature_range = feature_range
+         self.copy = copy
+         self.clip = clip
+
+     def _reset(self):
+         """Reset internal data-dependent state of the scaler, if necessary.
+
+         __init__ parameters are not touched.
+         """
+         # Checking one attribute is enough, because they are all set together
+         # in partial_fit
+         if hasattr(self, "scale_"):
+             del self.scale_
+             del self.min_
+             del self.n_samples_seen_
+             del self.data_min_
+             del self.data_max_
+             del self.data_range_
+
+     def fit(self, X):
+         # Reset internal state before fitting
+         self._reset()
+         return self.partial_fit(X)
+
+     def partial_fit(self, X):
+         feature_range = self.feature_range
+         if feature_range[0] >= feature_range[1]:
+             raise ValueError(
+                 "Minimum of desired feature range must be smaller than maximum. Got %s."
+                 % str(feature_range)
+             )
+
+         data_min = torch.min(X, axis=0)[0]
+         data_max = torch.max(X, axis=0)[0]
+
+         self.n_samples_seen_ = X.shape[0]
+
+         data_range = data_max - data_min
+         self.scale_ = (feature_range[1] - feature_range[0]) / _handle_zeros_in_scale(
+             data_range, copy=True
+         )
+         self.min_ = feature_range[0] - data_min * self.scale_
+         self.data_min_ = data_min
+         self.data_max_ = data_max
+         self.data_range_ = data_range
+         return self
+
+     def transform(self, X):
+         X *= self.scale_.to(X.device)
+         X += self.min_.to(X.device)
+         if self.clip:
+             torch.clip(X, self.feature_range[0], self.feature_range[1], out=X)
+         return X
+
+     def inverse_transform(self, X):
+         X -= self.min_[-X.shape[1]:].to(X.device)
+         X /= self.scale_[-X.shape[1]:].to(X.device)
+         return X
environment.yaml ADDED
@@ -0,0 +1,343 @@
1
+ name: FineNet
2
+ channels:
3
+ - anaconda
4
+ - pytorch
5
+ - conda-forge
6
+ - https://repo.anaconda.com/pkgs/main
7
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/peterjc123/
8
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
9
+ - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
10
+ - defaults
11
+ dependencies:
12
+ - _libgcc_mutex=0.1=main
13
+ - _openmp_mutex=5.1=1_gnu
14
+ - aiohttp=3.8.1=py38h0a891b7_1
15
+ - aiosignal=1.3.1=pyhd8ed1ab_0
16
+ - asttokens=2.2.1=pyhd8ed1ab_0
17
+ - async-timeout=4.0.3=pyhd8ed1ab_0
18
+ - attrs=23.1.0=pyh71513ae_1
19
+ - backcall=0.2.0=pyh9f0ad1d_0
20
+ - backports=1.0=pyhd8ed1ab_3
21
+ - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
22
+ - blas=1.0=mkl
23
+ - blinker=1.7.0=pyhd8ed1ab_0
24
+ - brotlipy=0.7.0=py38h0a891b7_1004
25
+ - bzip2=1.0.8=h7f98852_4
26
+ - c-ares=1.18.1=h7f98852_0
27
+ - ca-certificates=2023.11.17=hbcca054_0
28
+ - certifi=2023.11.17=pyhd8ed1ab_0
29
+ - cffi=1.15.0=py38h3931269_0
30
+ - charset-normalizer=2.1.1=pyhd8ed1ab_0
31
+ - colorama=0.4.6=pyhd8ed1ab_0
32
+ - cpuonly=1.0=0
33
+ - cryptography=37.0.2=py38h2b5fc30_0
34
+ - cudatoolkit=11.3.1=h9edb442_10
35
+ - cycler=0.11.0=pyhd8ed1ab_0
36
+ - dataclasses=0.8=pyhc8e2a94_3
37
+ - debugpy=1.5.1=py38h295c915_0
38
+ - entrypoints=0.4=pyhd8ed1ab_0
39
+ - executing=1.2.0=pyhd8ed1ab_0
40
+ - ffmpeg=4.3=hf484d3e_0
41
+ - freetype=2.10.4=h0708190_1
42
+ - frozenlist=1.3.0=py38h0a891b7_1
43
+ - fsspec=2023.5.0=pyh1a96a4e_0
44
+ - future=0.18.3=pyhd8ed1ab_0
45
+ - geos=3.10.2=h9c3ff4c_0
46
+ - giflib=5.2.1=h5eee18b_1
47
+ - gmp=6.2.1=h58526e2_0
48
+ - gnutls=3.6.13=h85f3911_1
49
+ - icu=67.1=he1b5a44_0
50
+ - idna=3.4=pyhd8ed1ab_0
51
+ - importlib-metadata=6.8.0=pyha770c72_0
52
+ - intel-openmp=2021.4.0=h06a4308_3561
53
+ - ipykernel=6.14.0=py38h7f3c49e_0
54
+ - ipython=8.4.0=py38h578d9bd_0
55
+ - jedi=0.18.2=pyhd8ed1ab_0
56
+ - jpeg=9e=h166bdaf_1
57
+ - jupyter_client=7.0.6=pyhd8ed1ab_0
58
+ - jupyter_core=4.12.0=py38h578d9bd_0
59
+ - kiwisolver=1.4.4=py38h6a678d5_0
60
+ - lame=3.100=h7f98852_1001
61
+ - lcms2=2.12=hddcbb42_0
62
+ - ld_impl_linux-64=2.38=h1181459_1
63
+ - libffi=3.4.2=h6a678d5_6
64
+ - libgcc-ng=11.2.0=h1234567_1
65
+ - libgfortran-ng=7.5.0=ha8ba4b0_17
66
+ - libgfortran4=7.5.0=ha8ba4b0_17
67
+ - libgomp=11.2.0=h1234567_1
68
+ - libiconv=1.17=h166bdaf_0
69
+ - libpng=1.6.37=h21135ba_2
70
+ - libprotobuf=3.18.0=h780b84a_1
71
+ - libsodium=1.0.18=h36c2ea0_1
72
+ - libstdcxx-ng=11.2.0=h1234567_1
73
+ - libtiff=4.2.0=hf544144_3
74
+ - libuv=1.43.0=h7f98852_0
75
+ - libwebp=1.2.2=h55f646e_0
76
+ - libwebp-base=1.2.2=h7f98852_1
77
+ - lightning-utilities=0.8.0=pyhd8ed1ab_0
78
+ - lz4-c=1.9.3=h9c3ff4c_1
79
+ - mapbox_earcut=1.0.0=py38h43d8883_3
80
+ - matplotlib-base=3.2.2=py38h5d868c9_1
81
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
82
+ - mkl=2021.4.0=h06a4308_640
83
+ - mkl-service=2.4.0=py38h95df7f1_0
84
+ - mkl_fft=1.3.1=py38h8666266_1
85
+ - mkl_random=1.2.2=py38h1abd341_0
86
+ - mpi=1.0=mpich
87
+ - mpi4py=3.1.4=py38hfc96bbd_0
88
+ - mpich=3.3.2=hc856adb_0
89
+ - multidict=6.0.2=py38h0a891b7_1
90
+ - ncurses=6.3=h5eee18b_3
91
+ - nest-asyncio=1.5.6=pyhd8ed1ab_0
92
+ - nettle=3.6=he412f7d_0
93
+ - networkx=3.1=pyhd8ed1ab_0
94
+ - ninja=1.11.0=h924138e_0
95
+ - oauthlib=3.2.2=pyhd8ed1ab_0
96
+ - olefile=0.46=pyh9f0ad1d_1
97
+ - openh264=2.1.1=h780b84a_0
98
+ - openjpeg=2.4.0=hb52868f_1
99
+ - openssl=1.1.1w=h7f8727e_0
100
+ - parso=0.8.3=pyhd8ed1ab_0
101
+ - pexpect=4.8.0=pyh1a96a4e_2
102
+ - pickleshare=0.7.5=py_1003
103
+ - prompt-toolkit=3.0.39=pyha770c72_0
104
+ - ptyprocess=0.7.0=pyhd3deb0d_0
105
+ - pure_eval=0.2.2=pyhd8ed1ab_0
106
+ - pyasn1=0.5.0=pyhd8ed1ab_0
107
+ - pyasn1-modules=0.3.0=pyhd8ed1ab_0
108
+ - pycparser=2.21=pyhd8ed1ab_0
109
+ - pyjwt=2.8.0=pyhd8ed1ab_0
110
+ - pyopenssl=22.0.0=pyhd8ed1ab_1
111
+ - pyrender=0.1.45=pyh8a188c0_3
112
+ - pysocks=1.7.1=pyha2e5f31_6
113
+ - python=3.8.15=h7a1cb2a_2
114
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
115
+ - python_abi=3.8=2_cp38
116
+ - pytorch-lightning=1.5.8=pyhd8ed1ab_0
117
+ - pytorch-mutex=1.0=cuda
118
+ - pyu2f=0.1.5=pyhd8ed1ab_0
119
+ - pyyaml=6.0=py38h0a891b7_4
120
+ - pyzmq=19.0.2=py38ha71036d_2
121
+ - readline=8.2=h5eee18b_0
122
+ - requests=2.28.2=pyhd8ed1ab_0
123
+ - requests-oauthlib=1.3.1=pyhd8ed1ab_0
124
+ - rsa=4.9=pyhd8ed1ab_0
125
+ - six=1.16.0=pyh6c4a22f_0
126
+ - sqlite=3.40.0=h5082296_0
127
+ - stack_data=0.6.2=pyhd8ed1ab_0
128
+ - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
129
+ - tk=8.6.12=h1ccaba5_0
130
+ - tornado=6.1=py38h0a891b7_3
131
+ - tqdm=4.65.0=pyhd8ed1ab_1
132
+ - traitlets=5.9.0=pyhd8ed1ab_0
133
+ - typing-extensions=4.4.0=hd8ed1ab_0
134
+ - typing_extensions=4.4.0=pyha770c72_0
135
+ - urllib3=1.26.14=pyhd8ed1ab_0
136
+ - wcwidth=0.2.6=pyhd8ed1ab_0
137
+ - xz=5.2.8=h5eee18b_0
138
+ - yaml=0.2.5=h7f98852_2
139
+ - yarl=1.7.2=py38h0a891b7_2
140
+ - zeromq=4.3.4=h9c3ff4c_1
141
+ - zlib=1.2.13=h5eee18b_0
142
+ - zstd=1.5.0=ha95c52a_0
143
+ - pip:
144
+ - absl-py==1.4.0
145
+ - accelerate==0.19.0
146
+ - alembic==1.12.0
147
+ - aniposelib==0.4.3
148
+ - antlr4-python3-runtime==4.8
149
+ - appdirs==1.4.4
150
+ - audioread==3.0.0
151
+ - autopage==0.5.1
152
+ - backoff==2.2.1
153
+ - beautifulsoup4==4.12.2
154
+ - bertopic==0.15.0
155
+ - blobfile==2.0.2
156
+ - boilerpy3==1.0.6
157
+ - cachetools==5.3.1
158
+ - canals==0.2.2
159
+ - cattrs==23.1.2
160
+ - chumpy==0.69
161
+ - click==8.1.3
162
+ - cliff==4.3.0
163
+ - clip==1.0
164
+ - cmaes==0.10.0
165
+ - cmd2==2.4.3
166
+ - colorlog==6.7.0
167
+ - commonmark==0.9.1
168
+ - configer==1.3.1
169
+ - configparser==5.3.0
170
+ - contourpy==1.0.7
171
+ - coremltools==6.1
172
+ - cython==0.29.35
173
+ - decorator==4.4.2
174
+ - diffusers==0.16.1
175
+ - dill==0.3.6
176
+ - docker-pycreds==0.4.0
177
+ - docopt==0.6.2
178
+ - easydict==1.7
179
+ - einops==0.6.1
180
+ - etils==0.9.0
181
+ - events==0.4
182
+ - exceptiongroup==1.1.2
183
+ - farm-haystack==1.18.1
184
+ - filelock==3.12.0
185
+ - fire==0.1.3
186
+ - fonttools==4.39.4
187
+ - freetype-py==2.3.0
188
+ - ftfy==6.1.1
189
+ - fvcore==0.1.5.post20221221
190
+ - gdown==4.7.1
191
+ - gitdb==4.0.10
192
+ - gitpython==3.1.31
193
+ - google-auth==2.22.0
194
+ - google-auth-oauthlib==1.0.0
195
+ - googleapis-common-protos==1.57.0
196
+ - greenlet==2.0.2
197
+ - grpcio==1.56.0
198
+ - h5py==3.9.0
199
+ - hdbscan==0.8.33
200
+ - huggingface-hub==0.14.1
201
+ - hydra==2.5
202
+ - hydra-colorlog==1.1.0.dev1
203
+ - hydra-core==1.1.0rc1
204
+ - hydra-optuna-sweeper==1.1.0.dev2
205
+ - imageio==2.27.0
206
+ - imageio-ffmpeg==0.4.9
207
+ - importlib-resources==5.10.1
208
+ - inflect==7.0.0
209
+ - iopath==0.1.10
210
+ - joblib==1.2.0
211
+ - json-tricks==3.17.1
212
+ - jsonschema==4.18.4
213
+ - jsonschema-specifications==2023.7.1
214
+ - jukebox==1.0
215
+ - lazy-imports==0.3.1
216
+ - lazy-loader==0.2
217
+ - librosa==0.7.2
218
+ - llvmlite==0.31.0
219
+ - lxml==4.9.2
220
+ - mako==1.2.4
221
+ - markdown==3.4.3
222
+ - markupsafe==2.1.3
223
+ - matplotlib==3.7.3
224
+ - monotonic==1.6
225
+ - more-itertools==10.0.0
226
+ - moviepy==1.0.3
227
+ - mpmath==1.2.1
228
+ - msgpack==1.0.5
229
+ - multiprocess==0.70.14
230
+ - netifaces==0.11.0
231
+ - nltk==3.8.1
232
+ - num2words==0.5.12
233
+ - numba==0.48.0
234
+ - numpy==1.24.4
235
+ - omegaconf==2.1.0rc1
236
+ - onnx==1.12.0
237
+ - onnxoptimizer==0.3.2
238
+ - onnxsim==0.4.10
239
+ - opencv-contrib-python==4.8.0.74
240
+ - opencv-python==4.7.0.72
241
+ - optuna==2.4.0
242
+ - p-tqdm==1.4.0
243
+ - packaging==22.0
244
+ - pandas==1.2.4
245
+ - pathos==0.3.0
246
+ - pathtools==0.1.2
247
+ - pbr==5.11.1
248
+ - pickle5==0.0.11
249
+ - pillow==9.5.0
250
+ - pip==23.3.1
251
+ - pkgutil-resolve-name==1.3.10
252
+ - platformdirs==3.9.1
253
+ - plotly==5.17.0
254
+ - pooch==1.6.0
255
+ - portalocker==2.7.0
256
+ - posthog==3.0.1
257
+ - pox==0.3.2
258
+ - ppft==1.7.6.6
259
+ - prettytable==3.9.0
260
+ - proglog==0.1.10
261
+ - promise==2.3
262
+ - prompthub-py==4.0.0
263
+ - protobuf==3.20.1
264
+ - psutil==5.9.5
265
+ - publicip==1.0.1
266
+ - pycocotools==2.0.6
267
+ - pycryptodomex==3.17
268
+ - pydantic==1.10.11
269
+ - pydeprecate==0.3.2
270
+ - pydub==0.25.1
271
+ - pyglet==1.4.0b1
272
+ - pygments==2.13.0
273
+ - pynndescent==0.5.10
274
+ - pyopengl==3.1.0
275
+ - pyopengl-accelerate==3.1.7
276
+ - pyparsing==3.0.9
277
+ - pyperclip==1.8.2
278
+ - python-dotenv==0.17.1
279
+ - pytorch3d==0.3.0
280
+ - pytz==2022.6
281
+ - pywavelets==1.4.1
282
+ - quantulum3==0.9.0
283
+ - rank-bm25==0.2.2
284
+ - referencing==0.30.0
285
+ - regex==2023.5.5
286
+ - requests-cache==0.9.8
287
+ - resampy==0.3.1
288
+ - rich==12.6.0
289
+ - rpds-py==0.9.2
290
+ - safetensors==0.3.1
291
+ - scikit-image==0.18.0
292
+ - scikit-learn==1.2.2
293
+ - scipy==1.10.1
294
+ - sentence-transformers==2.2.2
295
+ - sentencepiece==0.1.99
296
+ - sentry-sdk==1.25.0
297
+ - setproctitle==1.3.2
298
+ - setuptools==68.2.2
299
+ - shapely==2.0.1
300
+ - smmap==5.0.0
301
+ - smplx==0.1.28
302
+ - soundfile==0.10.3.post1
303
+ - soupsieve==2.4.1
304
+ - soxr==0.3.5
305
+ - sqlalchemy==2.0.20
306
+ - sseclient-py==1.7.2
307
+ - stevedore==5.1.0
308
+ - sympy==1.11.1
309
+ - tabulate==0.9.0
310
+ - tbb==2021.10.0
311
+ - tenacity==8.2.2
312
+ - tensorboard==2.13.0
313
+ - tensorboard-data-server==0.7.1
314
+ - tensorboardx==1.6
315
+ - tensorflow-datasets==4.7.0
316
+ - tensorflow-metadata==1.12.0
317
+ - tf2onnx==1.13.0
318
+ - threadpoolctl==3.1.0
319
+ - tifffile==2023.7.4
320
+ - tiktoken==0.4.0
321
+ - timm==0.4.5
322
+ - tokenizers==0.13.3
323
+ - toml==0.10.2
324
+ - torch==1.13.1+cu116
325
+ - torch-tb-profiler==0.4.1
326
+ - torchaudio==0.13.1+cu116
327
+ - torchgeometry==0.1.2
328
+ - torchmetrics==0.7.0
329
+ - torchvision==0.14.1+cu116
330
+ - transformers==4.30.1
331
+ - transforms3d==0.4.1
332
+ - trimesh==3.9.24
333
+ - tzdata==2023.3
334
+ - umap-learn==0.5.4
335
+ - unidecode==1.1.1
336
+ - url-normalize==1.4.3
337
+ - wandb==0.15.2
338
+ - werkzeug==2.3.6
339
+ - wget==3.2
340
+ - wheel==0.41.2
341
+ - yacs==0.1.8
342
+ - zipp==3.16.2
343
+ prefix: /home/lrh/.conda/envs/py38
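The Linux spec above pins torch==1.13.1+cu116 with matching torchvision/torchaudio builds. A quick post-install sanity check, as a minimal sketch (assumes an NVIDIA machine; the +cu116 wheels come from PyTorch's extra index, not PyPI):

    import torch, torchaudio, torchvision

    # Expect "1.13.1+cu116" and True on a correctly provisioned CUDA 11.6 box.
    print(torch.__version__, torchvision.__version__, torchaudio.__version__)
    print(torch.cuda.is_available())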
environment_macos.yaml ADDED
@@ -0,0 +1,64 @@
+ name: FineNet
+ channels:
+   - pytorch
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.9
+   - numpy
+   - scipy
+   - matplotlib
+   - pandas
+   - pyyaml
+   - h5py
+   - tqdm
+   - ipython
+   - jupyter_client
+   - cython
+   - ffmpeg
+   - pip
+   - pip:
+       - torch
+       - torchaudio
+       - torchvision
+       - pytorch-lightning==1.9.5
+       - torchmetrics==0.11.4
+       - accelerate
+       - einops
+       - smplx
+       - trimesh
+       - pyrender
+       - opencv-python
+       - opencv-contrib-python
+       - scikit-learn
+       - scikit-image
+       - transformers
+       - diffusers
+       - librosa
+       - soundfile
+       - moviepy
+       - imageio
+       - imageio-ffmpeg
+       - hydra-core==1.3.2
+       - omegaconf==2.3.0
+       - wandb
+       - tensorboard
+       - tensorboardx
+       - easydict
+       - fire
+       - ftfy
+       - regex
+       - pillow
+       - plotly
+       - gdown
+       - huggingface-hub
+       - safetensors
+       - sentence-transformers
+       - pydub
+       - json-tricks
+       - yacs
+       - fvcore
+       - iopath
+       - tabulate
+       - rich
+       - click
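The macOS spec leaves the PyTorch packages unpinned so pip can pick wheels with Metal (MPS) support. generate_all.py below selects "mps" when available, so it is worth confirming the backend after creating this environment (a minimal sketch, nothing repo-specific):

    import torch

    # True on Apple silicon with an MPS-enabled torch build;
    # otherwise the scripts in this repo fall back to CPU.
    print(torch.backends.mps.is_available())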
generate_all.py ADDED
@@ -0,0 +1,57 @@
+ import os
+
+ import numpy as np
+ import torch
+
+ from args import FineDance_parse_test_opt
+ from train_seq import EDGE
+ from dataset.FineDance_dataset import get_train_test_list
+
+ # full test split:
+ # test_list = ["063", "132", "143", "036", "098", "198", "130", "012", "211", "193", "179", "065", "137", "161", "092", "120", "037", "109", "204", "144"]
+ # NOTE: the split returned by get_train_test_list() inside test() shadows this module-level list.
+ test_list = ["063", "144"]
+
+ music_dir = "data/finedance/div_by_time/music_npy_120"
+ count = 10  # number of dances sampled per song
+
+
+ def test(opt):
+     train_list, test_list, ignore_list = get_train_test_list(opt.datasplit)
+
+     # load the checkpoint once, outside the per-song loop
+     model = EDGE(opt, opt.feature_type, opt.checkpoint)
+     model.eval()
+
+     device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+
+     # directory for optionally saving the dances for eval
+     fk_out = None
+     if opt.save_motions:
+         fk_out = opt.motion_save_dir
+
+     for file in os.listdir(music_dir):
+         if file[:3] in ignore_list:
+             continue
+         if file[:3] not in test_list:
+             continue
+
+         file_name = file[:-4]
+         music_fea = np.load(os.path.join(music_dir, file))
+         music_fea = torch.from_numpy(music_fea).float().to(device).unsqueeze(0)
+         music_fea = music_fea.repeat(count, 1, 1)  # `count` samples per song
+         all_filenames = [file_name] * count
+
+         data_tuple = None, music_fea, all_filenames
+         model.render_sample(
+             data_tuple, "test", opt.render_dir, render_count=10, mode='normal', fk_out=fk_out, render=not opt.no_render
+         )
+     print("Done")
+
+
+ if __name__ == "__main__":
+     opt = FineDance_parse_test_opt()
+     test(opt)
+
+ # python generate_all.py --save_motions
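generate_all.py reads precomputed per-slice music features rather than raw audio. A quick sanity check on one file (the file name is an example, and the 35-D layout is assumed to match extract_features in generate_dance.py below):

    import numpy as np

    fea = np.load("data/finedance/div_by_time/music_npy_120/063_0.npy")
    print(fea.shape)  # expected (120, 35): full_seq_len frames of 35-D features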
generate_dance.py ADDED
@@ -0,0 +1,240 @@
+ """
+ End-to-end dance generation from a music file.
+
+ Usage:
+     python generate_dance.py /path/to/music.mp3
+     python generate_dance.py /path/to/music.mp3 --output my_dance.mp4
+ """
+
+ import argparse
+ import glob
+ import os
+ import shutil
+ import subprocess
+ from functools import cmp_to_key
+ from tempfile import TemporaryDirectory
+
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+ import librosa
+ import numpy as np
+ import soundfile as sf
+ import torch
+ from tqdm import tqdm
+
+ from train_seq import EDGE
+ from render import MovieMaker, motion_data_load_process
+
+
+ # --- Audio utilities (from test.py) ---
+
+ def slice_audio(audio_file, stride, length, out_dir):
+     # cut the song into overlapping windows of `length` seconds, one every `stride` seconds
+     audio, sr = librosa.load(audio_file, sr=None)
+     file_name = os.path.splitext(os.path.basename(audio_file))[0]
+     start_idx = 0
+     idx = 0
+     window = int(length * sr)
+     stride_step = int(stride * sr)
+     while start_idx <= len(audio) - window:
+         audio_slice = audio[start_idx : start_idx + window]
+         sf.write(f"{out_dir}/{file_name}_slice{idx}.wav", audio_slice, sr)
+         start_idx += stride_step
+         idx += 1
+     return idx
+
+
+ def extract_features(fpath, full_seq_len=120):
+     # 35-D music features at 30 FPS: onset envelope (1) + 20 MFCCs + 12 chroma
+     # + onset one-hot (1) + beat one-hot (1), truncated to 4 s (120 frames)
+     FPS = 30
+     HOP_LENGTH = 512
+     SR = FPS * HOP_LENGTH
+
+     data, _ = librosa.load(fpath, sr=SR)
+     envelope = librosa.onset.onset_strength(y=data, sr=SR)
+     mfcc = librosa.feature.mfcc(y=data, sr=SR, n_mfcc=20).T
+     chroma = librosa.feature.chroma_cens(y=data, sr=SR, hop_length=HOP_LENGTH, n_chroma=12).T
+
+     peak_idxs = librosa.onset.onset_detect(
+         onset_envelope=envelope.flatten(), sr=SR, hop_length=HOP_LENGTH
+     )
+     peak_onehot = np.zeros_like(envelope, dtype=np.float32)
+     peak_onehot[peak_idxs] = 1.0
+
+     start_bpm = librosa.beat.tempo(y=librosa.load(fpath)[0])[0]
+     tempo, beat_idxs = librosa.beat.beat_track(
+         onset_envelope=envelope, sr=SR, hop_length=HOP_LENGTH,
+         start_bpm=start_bpm, tightness=100,
+     )
+     beat_onehot = np.zeros_like(envelope, dtype=np.float32)
+     beat_onehot[beat_idxs] = 1.0
+
+     audio_feature = np.concatenate(
+         [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
+         axis=-1,
+     )
+     audio_feature = audio_feature[:4 * FPS]
+     return audio_feature
+
+
+ key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])
+
+ def stringintcmp_(a, b):
+     # sort first by song name, then by numeric slice index
+     aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
+     ka, kb = key_func(a), key_func(b)
+     if aa < bb:
+         return -1
+     if aa > bb:
+         return 1
+     if ka < kb:
+         return -1
+     if ka > kb:
+         return 1
+     return 0
+
+ stringintkey = cmp_to_key(stringintcmp_)
+
+
+ # --- Model loading ---
+
+ class _Opt:
+     """Minimal config namespace for the EDGE model."""
+     feature_type = "baseline"
+     full_seq_len = 120
+     windows = 10
+     nfeats = 319
+     do_normalize = False
+     datasplit = "cross_genre"
+     project = "experiments/finedance_seq_120_genre/train"
+     exp_name = "finedance_seq_120_genre"
+     render_dir = "tmp_renders"
+     batch_size = 64
+     epochs = 1
+     save_interval = 10
+     ema_interval = 1
+     checkpoint = ""
+     wandb_pj_name = "finedance_seq"
+
+
+ def load_model(checkpoint_path="assets/checkpoints/train-2000.pt"):
+     """Load the EDGE model once. Returns (model, opt)."""
+     opt = _Opt()
+     model = EDGE(opt, opt.feature_type, checkpoint_path)
+     model.eval()
+     return model, opt
+
+
+ def _setup_render_args():
+     """Inject render.py global args for MovieMaker."""
+     import render as render_module
+     render_module.args = argparse.Namespace(
+         mode="smplx", fps=30, gpu="0", modir="", save_path=None
+     )
+
+
+ # --- Main pipeline ---
+
+ def generate(music_path, output_path, model=None, visualizer=None, log_fn=print):
+     """
+     Generate a dance video from a music file.
+
+     Args:
+         music_path: Path to input audio (mp3, wav, etc.)
+         output_path: Where to save the output mp4
+         model: Pre-loaded (model, opt) tuple. If None, loads fresh.
+         visualizer: Pre-built MovieMaker instance. If None, creates one.
+         log_fn: Callable for status messages (default: print)
+     """
+     music_path = os.path.abspath(music_path)
+     songname = os.path.splitext(os.path.basename(music_path))[0]
+
+     # Step 1: Convert to WAV if needed
+     log_fn(f"[1/5] Preparing audio: {os.path.basename(music_path)}")
+     temp_root = TemporaryDirectory()
+     wav_dir = os.path.join(temp_root.name, "wav")
+     os.makedirs(wav_dir)
+     wav_path = os.path.join(wav_dir, songname + ".wav")
+
+     if music_path.lower().endswith(".wav"):
+         shutil.copy2(music_path, wav_path)
+     else:
+         subprocess.run(
+             ["ffmpeg", "-i", music_path, wav_path, "-y"],
+             capture_output=True, check=True,
+         )
+
+     # Step 2: Slice and extract features
+     log_fn("[2/5] Extracting audio features...")
+     slice_dir = os.path.join(temp_root.name, "slices")
+     os.makedirs(slice_dir)
+     stride = 60 / 30  # 2 seconds
+     full_seq_len = 120
+     slice_audio(wav_path, stride, full_seq_len / 30, slice_dir)
+
+     file_list = sorted(glob.glob(f"{slice_dir}/*.wav"), key=stringintkey)
+     out_length = 30  # seconds
+     sample_size = int(out_length / stride) - 1
+
+     cond_list = []
+     for file in tqdm(file_list[:sample_size]):
+         reps = extract_features(file)[:full_seq_len]
+         cond_list.append(reps)
+     cond = torch.from_numpy(np.array(cond_list))
+     filenames = file_list[:sample_size]
+
+     # Step 3: Generate motion
+     log_fn("[3/5] Generating dance motion...")
+     if model is None:
+         edge_model, opt = load_model()
+     else:
+         edge_model, opt = model
+
+     motion_dir = os.path.join(temp_root.name, "motions")
+     os.makedirs(motion_dir)
+
+     data_tuple = (None, cond, filenames)
+     edge_model.render_sample(
+         data_tuple, "gen", temp_root.name, render_count=-1,
+         fk_out=motion_dir, mode="long", render=False,
+     )
+
+     # Step 4: Render video
+     log_fn("[4/5] Rendering video...")
+     motion_file = glob.glob(os.path.join(motion_dir, "*.pkl"))[0]
+     modata = motion_data_load_process(motion_file)
+
+     video_dir = os.path.join(temp_root.name, "video")
+     os.makedirs(video_dir)
+
+     _setup_render_args()
+     if visualizer is None:
+         visualizer = MovieMaker(save_path=video_dir)
+     else:
+         visualizer.save_path = video_dir
+     visualizer.run(modata, tab=songname, music_file=wav_path)
+
+     # Step 5: Copy final output
+     log_fn("[5/5] Saving output...")
+     rendered_file = os.path.join(video_dir, songname + "z.mp4")
+     output_path = os.path.abspath(output_path)
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+     shutil.move(rendered_file, output_path)
+
+     temp_root.cleanup()
+     log_fn(f"Done! Output saved to: {output_path}")
+     return output_path
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Generate a dance video from a music file.")
+     parser.add_argument("music", type=str, help="Path to the input music file (mp3, wav, etc.)")
+     parser.add_argument("--output", type=str, default=None, help="Output video path (default: output/<songname>_dance.mp4)")
+     args = parser.parse_args()
+
+     if args.output is None:
+         songname = os.path.splitext(os.path.basename(args.music))[0]
+         args.output = os.path.join("output", f"{songname}_dance.mp4")
+
+     generate(args.music, args.output)
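generate() accepts a pre-loaded (model, opt) tuple, so the checkpoint is read once when several songs are processed in a row. A short sketch (the song paths are examples):

    from generate_dance import generate, load_model

    model_and_opt = load_model()  # defaults to assets/checkpoints/train-2000.pt
    for song in ["custom_music/demo1.mp3", "custom_music/demo2.wav"]:
        name = song.rsplit("/", 1)[-1].rsplit(".", 1)[0]
        generate(song, f"output/{name}_dance.mp4", model=model_and_opt)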
model/adan.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+ from torch.optim import Optimizer
+
+
+ def exists(val):
+     return val is not None
+
+
+ class Adan(Optimizer):
+     def __init__(
+         self,
+         params,
+         lr=1e-3,
+         betas=(0.02, 0.08, 0.01),
+         eps=1e-8,
+         weight_decay=0,
+         restart_cond: callable = None,
+     ):
+         assert len(betas) == 3
+
+         defaults = dict(
+             lr=lr,
+             betas=betas,
+             eps=eps,
+             weight_decay=weight_decay,
+             restart_cond=restart_cond,
+         )
+
+         super().__init__(params, defaults)
+
+     def step(self, closure=None):
+         loss = None
+
+         if exists(closure):
+             loss = closure()
+
+         for group in self.param_groups:
+
+             lr = group["lr"]
+             beta1, beta2, beta3 = group["betas"]
+             weight_decay = group["weight_decay"]
+             eps = group["eps"]
+             restart_cond = group["restart_cond"]
+
+             for p in group["params"]:
+                 if not exists(p.grad):
+                     continue
+
+                 data, grad = p.data, p.grad.data
+                 assert not grad.is_sparse
+
+                 state = self.state[p]
+
+                 if len(state) == 0:
+                     state["step"] = 0
+                     state["prev_grad"] = torch.zeros_like(grad)
+                     state["m"] = torch.zeros_like(grad)
+                     state["v"] = torch.zeros_like(grad)
+                     state["n"] = torch.zeros_like(grad)
+
+                 step, m, v, n, prev_grad = (
+                     state["step"],
+                     state["m"],
+                     state["v"],
+                     state["n"],
+                     state["prev_grad"],
+                 )
+
+                 if step > 0:
+                     prev_grad = state["prev_grad"]
+
+                 # main algorithm: first moment, gradient-difference moment, second moment
+
+                 m.mul_(1 - beta1).add_(grad, alpha=beta1)
+
+                 grad_diff = grad - prev_grad
+
+                 v.mul_(1 - beta2).add_(grad_diff, alpha=beta2)
+
+                 next_n = (grad + (1 - beta2) * grad_diff) ** 2
+
+                 n.mul_(1 - beta3).add_(next_n, alpha=beta3)
+
+                 # bias correction terms
+
+                 step += 1
+
+                 correct_m, correct_v, correct_n = map(
+                     lambda beta: 1 / (1 - (1 - beta) ** step), (beta1, beta2, beta3)
+                 )
+
+                 # gradient step
+
+                 def grad_step_(data, m, v, n):
+                     weighted_step_size = lr / (n * correct_n).sqrt().add_(eps)
+
+                     denom = 1 + weight_decay * lr
+
+                     data.addcmul_(
+                         weighted_step_size,
+                         (m * correct_m + (1 - beta2) * v * correct_v),
+                         value=-1.0,
+                     ).div_(denom)
+
+                 grad_step_(data, m, v, n)
+
+                 # restart condition
+
+                 if exists(restart_cond) and restart_cond(state):
+                     m.data.copy_(grad)
+                     v.zero_()
+                     n.data.copy_(grad ** 2)
+
+                     grad_step_(data, m, v, n)
+
+                 # set new incremented step
+
+                 prev_grad.copy_(grad)
+                 state["step"] = step
+
+         return loss
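A minimal usage sketch for the optimizer above (the hyperparameters are illustrative, not the repo's training configuration):

    import torch
    import torch.nn.functional as F
    from model.adan import Adan

    net = torch.nn.Linear(10, 1)
    optimizer = Adan(net.parameters(), lr=4e-4, betas=(0.02, 0.08, 0.01), weight_decay=0.02)

    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = F.mse_loss(net(x), y)
    loss.backward()
    optimizer.step()       # update from m, v, n and the gradient difference
    optimizer.zero_grad()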
model/diffusion.py ADDED
@@ -0,0 +1,741 @@
1
+ import copy
2
+ import os
3
+ import pickle
4
+ from pathlib import Path
5
+ from functools import partial
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import reduce
12
+ from p_tqdm import p_map
13
+ from pytorch3d.transforms import (axis_angle_to_quaternion,
14
+ quaternion_to_axis_angle)
15
+ from tqdm import tqdm
16
+
17
+ from dataset.quaternion import ax_from_6v, quat_slerp
18
+ from vis import skeleton_render
19
+ from vis import SMPLX_Skeleton
20
+ from dataset.preprocess import My_Normalizer as Normalizer
21
+
22
+ from .utils import extract, make_beta_schedule
23
+
24
+ def identity(t, *args, **kwargs):
25
+ return t
26
+
27
+ class EMA:
28
+ def __init__(self, beta):
29
+ super().__init__()
30
+ self.beta = beta
31
+
32
+ def update_model_average(self, ma_model, current_model):
33
+ for current_params, ma_params in zip(
34
+ current_model.parameters(), ma_model.parameters()
35
+ ):
36
+ old_weight, up_weight = ma_params.data, current_params.data
37
+ ma_params.data = self.update_average(old_weight, up_weight)
38
+
39
+ def update_average(self, old, new):
40
+ if old is None:
41
+ return new
42
+ return old * self.beta + (1 - self.beta) * new
43
+
44
+
45
+ class GaussianDiffusion(nn.Module):
46
+ def __init__(
47
+ self,
48
+ model,
49
+ opt,
50
+ horizon,
51
+ repr_dim,
52
+ smplx_model,
53
+ n_timestep=1000,
54
+ schedule="linear",
55
+ loss_type="l1",
56
+ clip_denoised=True,
57
+ predict_epsilon=True,
58
+ guidance_weight=3,
59
+ use_p2=False,
60
+ cond_drop_prob=0.2,
61
+ do_normalize=False,
62
+ ):
63
+ super().__init__()
64
+ self.horizon = horizon
65
+ self.transition_dim = repr_dim
66
+ self.model = model
67
+ self.ema = EMA(0.9999)
68
+ self.master_model = copy.deepcopy(self.model)
69
+ self.normalizer = None
70
+ self.do_normalize = do_normalize
71
+ self.opt = opt
72
+
73
+ self.cond_drop_prob = cond_drop_prob
74
+
75
+ # make a SMPL instance for FK module
76
+ self.smplx_fk = smplx_model
77
+
78
+ betas = torch.Tensor(
79
+ make_beta_schedule(schedule=schedule, n_timestep=n_timestep)
80
+ )
81
+ alphas = 1.0 - betas
82
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
83
+ alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
84
+
85
+ self.n_timestep = int(n_timestep)
86
+ self.clip_denoised = clip_denoised
87
+ self.predict_epsilon = predict_epsilon
88
+
89
+ self.register_buffer("betas", betas)
90
+ self.register_buffer("alphas_cumprod", alphas_cumprod)
91
+ self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
92
+
93
+ self.guidance_weight = guidance_weight
94
+
95
+ # calculations for diffusion q(x_t | x_{t-1}) and others
96
+ self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
97
+ self.register_buffer(
98
+ "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
99
+ )
100
+ self.register_buffer(
101
+ "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
102
+ )
103
+ self.register_buffer(
104
+ "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
105
+ )
106
+ self.register_buffer(
107
+ "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
108
+ )
109
+
110
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
111
+ posterior_variance = (
112
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
113
+ )
114
+ self.register_buffer("posterior_variance", posterior_variance)
115
+
116
+ ## log calculation clipped because the posterior variance
117
+ ## is 0 at the beginning of the diffusion chain
118
+ self.register_buffer(
119
+ "posterior_log_variance_clipped",
120
+ torch.log(torch.clamp(posterior_variance, min=1e-20)),
121
+ )
122
+ self.register_buffer(
123
+ "posterior_mean_coef1",
124
+ betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
125
+ )
126
+ self.register_buffer(
127
+ "posterior_mean_coef2",
128
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
129
+ )
130
+
131
+ # p2 weighting
132
+ self.p2_loss_weight_k = 1
133
+ self.p2_loss_weight_gamma = 0.5 if use_p2 else 0
134
+ self.register_buffer(
135
+ "p2_loss_weight",
136
+ (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
137
+ ** -self.p2_loss_weight_gamma,
138
+ )
139
+
140
+ ## get loss coefficients and initialize objective
141
+ self.loss_fn = F.mse_loss if loss_type == "l2" else F.l1_loss
142
+
143
+ # ------------------------------------------ sampling ------------------------------------------#
144
+
145
+ def predict_start_from_noise(self, x_t, t, noise):
146
+ """
147
+ if self.predict_epsilon, model output is (scaled) noise;
148
+ otherwise, model predicts x0 directly
149
+ """
150
+ if self.predict_epsilon:
151
+ return (
152
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
153
+ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
154
+ )
155
+ else:
156
+ return noise
157
+
158
+ def predict_noise_from_start(self, x_t, t, x0):
159
+ return (
160
+ (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
161
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
162
+ )
163
+
164
+ def model_predictions(self, x, cond, t, weight=None, clip_x_start = False):
165
+ weight = weight if weight is not None else self.guidance_weight
166
+ model_output = self.model.guided_forward(x, cond, t, weight)
167
+ maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
168
+
169
+ x_start = model_output
170
+ x_start = maybe_clip(x_start)
171
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
172
+
173
+ return pred_noise, x_start
174
+
175
+ def q_posterior(self, x_start, x_t, t):
176
+ posterior_mean = (
177
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
178
+ + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
179
+ )
180
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
181
+ posterior_log_variance_clipped = extract(
182
+ self.posterior_log_variance_clipped, t, x_t.shape
183
+ )
184
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
185
+
186
+ def p_mean_variance(self, x, cond, t):
187
+ # guidance clipping
188
+ if t[0] > 1.0 * self.n_timestep:
189
+ weight = min(self.guidance_weight, 0)
190
+ elif t[0] < 0.1 * self.n_timestep:
191
+ weight = min(self.guidance_weight, 1)
192
+ else:
193
+ weight = self.guidance_weight
194
+
195
+ x_recon = self.predict_start_from_noise(
196
+ x, t=t, noise=self.model.guided_forward(x, cond, t, weight)
197
+ )
198
+
199
+ if self.clip_denoised:
200
+ x_recon.clamp_(-1.0, 1.0)
201
+ else:
202
+ raise RuntimeError("clip_denoised=False is not supported")
203
+
204
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
205
+ x_start=x_recon, x_t=x, t=t
206
+ )
207
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
208
+
209
+ @torch.no_grad()
210
+ def p_sample(self, x, cond, t):
211
+ b, *_, device = *x.shape, x.device
212
+ model_mean, _, model_log_variance, x_start = self.p_mean_variance(
213
+ x=x, cond=cond, t=t
214
+ )
215
+ noise = torch.randn_like(model_mean)
216
+ # no noise when t == 0
217
+ nonzero_mask = (1 - (t == 0).float()).reshape(
218
+ b, *((1,) * (len(noise.shape) - 1))
219
+ )
220
+ x_out = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
221
+ return x_out, x_start
222
+
223
+ @torch.no_grad()
224
+ def p_sample_loop(
225
+ self,
226
+ shape,
227
+ cond,
228
+ noise=None,
229
+ constraint=None,
230
+ return_diffusion=False,
231
+ start_point=None,
232
+ ):
233
+ device = self.betas.device
234
+
235
+ # default to diffusion over whole timescale
236
+ start_point = self.n_timestep if start_point is None else start_point
237
+ batch_size = shape[0]
238
+ x = torch.randn(shape, device=device) if noise is None else noise.to(device)
239
+ cond = cond.to(device)
240
+
241
+ if return_diffusion:
242
+ diffusion = [x]
243
+
244
+ for i in tqdm(reversed(range(0, start_point))):
245
+ # fill with i
246
+ timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
247
+ x, _ = self.p_sample(x, cond, timesteps)
248
+
249
+ if return_diffusion:
250
+ diffusion.append(x)
251
+
252
+ if return_diffusion:
253
+ return x, diffusion
254
+ else:
255
+ return x
256
+
257
+ @torch.no_grad()
258
+ def ddim_sample(self, shape, cond, **kwargs):
259
+ batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 1
260
+
261
+ times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
262
+ times = list(reversed(times.int().tolist()))
263
+ time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
264
+
265
+ x = torch.randn(shape, device = device)
266
+ cond = cond.to(device)
267
+
268
+ x_start = None
269
+
270
+ for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
271
+ time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
272
+ pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start = self.clip_denoised)
273
+
274
+ if time_next < 0:
275
+ x = x_start
276
+ continue
277
+
278
+ alpha = self.alphas_cumprod[time]
279
+ alpha_next = self.alphas_cumprod[time_next]
280
+
281
+ sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
282
+ c = (1 - alpha_next - sigma ** 2).sqrt()
283
+
284
+ noise = torch.randn_like(x)
285
+
286
+ x = x_start * alpha_next.sqrt() + \
287
+ c * pred_noise + \
288
+ sigma * noise
289
+ return x
290
+
291
+ @torch.no_grad()
292
+ def long_ddim_sample(self, shape, cond, **kwargs):
293
+ batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 1
294
+
295
+ if batch == 1:
296
+ return self.ddim_sample(shape, cond)
297
+
298
+ times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
299
+ times = list(reversed(times.int().tolist()))
300
+ weights = np.clip(np.linspace(0, self.guidance_weight * 2, sampling_timesteps), None, self.guidance_weight)
301
+ time_pairs = list(zip(times[:-1], times[1:], weights)) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
302
+
303
+ x = torch.randn(shape, device = device)
304
+ cond = cond.to(device)
305
+
306
+ assert batch > 1
307
+ assert x.shape[1] % 2 == 0
308
+ half = x.shape[1] // 2
309
+
310
+ x_start = None
311
+
312
+ for time, time_next, weight in tqdm(time_pairs, desc = 'sampling loop time step'):
313
+ time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
314
+ pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, weight=weight, clip_x_start = self.clip_denoised)
315
+
316
+ if time_next < 0:
317
+ x = x_start
318
+ continue
319
+
320
+ alpha = self.alphas_cumprod[time]
321
+ alpha_next = self.alphas_cumprod[time_next]
322
+
323
+ sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
324
+ c = (1 - alpha_next - sigma ** 2).sqrt()
325
+
326
+ noise = torch.randn_like(x)
327
+
328
+ x = x_start * alpha_next.sqrt() + \
329
+ c * pred_noise + \
330
+ sigma * noise
331
+
332
+ if time > 0:
333
+ # the first half of each sequence is the second half of the previous one
334
+ x[1:, :half] = x[:-1, half:]
335
+ return x
336
+
337
+ @torch.no_grad()
338
+ def inpaint_loop(
339
+ self,
340
+ shape,
341
+ cond,
342
+ noise=None,
343
+ constraint=None,
344
+ return_diffusion=False,
345
+ start_point=None,
346
+ ):
347
+ device = self.betas.device
348
+
349
+ batch_size = shape[0]
350
+ x = torch.randn(shape, device=device) if noise is None else noise.to(device)
351
+ cond = cond.to(device)
352
+ if return_diffusion:
353
+ diffusion = [x]
354
+
355
+ mask = constraint["mask"].to(device) # batch x horizon x channels
356
+ value = constraint["value"].to(device) # batch x horizon x channels
357
+
358
+ start_point = self.n_timestep if start_point is None else start_point
359
+ for i in tqdm(reversed(range(0, start_point))):
360
+ # fill with i
361
+ timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
362
+
363
+ # sample x from step i to step i-1
364
+ x, _ = self.p_sample(x, cond, timesteps)
365
+ # enforce constraint between each denoising step
366
+ value_ = self.q_sample(value, timesteps - 1) if (i > 0) else x
367
+ x = value_ * mask + (1.0 - mask) * x
368
+
369
+ if return_diffusion:
370
+ diffusion.append(x)
371
+
372
+ if return_diffusion:
373
+ return x, diffusion
374
+ else:
375
+ return x
376
+
377
+ @torch.no_grad()
378
+ def long_inpaint_loop(
379
+ self,
380
+ shape,
381
+ cond,
382
+ noise=None,
383
+ constraint=None,
384
+ return_diffusion=False,
385
+ start_point=None,
386
+ ):
387
+ device = self.betas.device
388
+
389
+ batch_size = shape[0]
390
+ x = torch.randn(shape, device=device) if noise is None else noise.to(device)
391
+ cond = cond.to(device)
392
+ if return_diffusion:
393
+ diffusion = [x]
394
+
395
+ assert x.shape[1] % 2 == 0
396
+ if batch_size == 1:
397
+ # there's no continuation to do, just do normal
398
+ return self.p_sample_loop(
399
+ shape,
400
+ cond,
401
+ noise=noise,
402
+ constraint=constraint,
403
+ return_diffusion=return_diffusion,
404
+ start_point=start_point,
405
+ )
406
+ assert batch_size > 1
407
+ half = x.shape[1] // 2
408
+
409
+ start_point = self.n_timestep if start_point is None else start_point
410
+ for i in tqdm(reversed(range(0, start_point))):
411
+ # fill with i
412
+ timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
413
+
414
+ # sample x from step i to step i-1
415
+ x, _ = self.p_sample(x, cond, timesteps)
416
+ # enforce constraint between each denoising step
417
+ if i > 0:
418
+ # the first half of each sequence is the second half of the previous one
419
+ x[1:, :half] = x[:-1, half:]
420
+
421
+ if return_diffusion:
422
+ diffusion.append(x)
423
+
424
+ if return_diffusion:
425
+ return x, diffusion
426
+ else:
427
+ return x
428
+
429
+ @torch.no_grad()
430
+ def conditional_sample(
431
+ self, shape, cond, constraint=None, *args, horizon=None, **kwargs
432
+ ):
433
+ """
434
+ conditions : [ (time, state), ... ]
435
+ """
436
+ device = self.betas.device
437
+ horizon = horizon or self.horizon
438
+
439
+ return self.p_sample_loop(shape, cond, *args, **kwargs)
440
+
441
+ # ------------------------------------------ training ------------------------------------------#
442
+
443
+ def q_sample(self, x_start, t, noise=None):
444
+ if noise is None:
445
+ noise = torch.randn_like(x_start)
446
+
447
+ sample = (
448
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
449
+ + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
450
+ )
451
+
452
+ return sample
453
+
454
+ def p_losses(self, x_start, cond, t):
455
+ noise = torch.randn_like(x_start)
456
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)  # diffuse x0 forward to x_t
457
+
458
+ # reconstruct
459
+ x_recon = self.model(x_noisy, cond, t, cond_drop_prob=self.cond_drop_prob)
460
+ assert noise.shape == x_recon.shape
461
+
462
+ model_out = x_recon
463
+ if self.predict_epsilon:
464
+ target = noise
465
+ else:
466
+ target = x_start
467
+
468
+ # full reconstruction loss
469
+ loss = self.loss_fn(model_out, target, reduction="none") # mse loss
470
+ loss = reduce(loss, "b ... -> b (...)", "mean")
471
+ loss = loss * extract(self.p2_loss_weight, t, loss.shape)
472
+
473
+ # split off contact from the rest
474
+ _, model_out_ = torch.split(
475
+ model_out, (4, model_out.shape[2] - 4), dim=2  # first 4 dims are foot contacts
476
+ )
477
+ _, target_ = torch.split(target, (4, target.shape[2] - 4), dim=2) # b, length, jxc
478
+
479
+ # velocity loss
480
+ target_v = target_[:, 1:] - target_[:, :-1]
481
+ model_out_v = model_out_[:, 1:] - model_out_[:, :-1]
482
+ v_loss = self.loss_fn(model_out_v, target_v, reduction="none")
483
+ v_loss = reduce(v_loss, "b ... -> b (...)", "mean")
484
+ v_loss = v_loss * extract(self.p2_loss_weight, t, v_loss.shape)
485
+
486
+ # FK loss
487
+ b, s, c = model_out.shape
488
+ model_contact, model_out = torch.split(model_out, (4, model_out.shape[2] - 4), dim=2)
489
+ target_contact, target = torch.split(target, (4, target.shape[2] - 4), dim=2) # b, length, jxc
490
+ model_x = model_out[:, :, :3] # root position
491
+ model_q = ax_from_6v(model_out[:, :, 3:].reshape(b, s, -1, 6))
492
+ target_x = target[:, :, :3]
493
+ target_q = ax_from_6v(target[:, :, 3:].reshape(b, s, -1, 6))
494
+ b, s, nums, c_ = model_q.shape
495
+
496
+ if self.opt.nfeats == 139 or self.opt.nfeats==135:
497
+ model_xp = self.smplx_fk.forward(model_q, model_x)
498
+ target_xp = self.smplx_fk.forward(target_q, target_x)
499
+ else:
500
+ model_q = model_q.view(b*s, -1)
501
+ target_q = target_q.view(b*s, -1)
502
+ model_x = model_x.view(-1, 3)
503
+ target_x = target_x.view(-1, 3)
504
+ model_xp = self.smplx_fk.forward(model_q, model_x)
505
+ target_xp = self.smplx_fk.forward(target_q, target_x)
506
+ model_xp = model_xp.view(b, s, -1, 3)
507
+ target_xp = target_xp.view(b, s, -1, 3)
508
+
511
+ fk_loss = self.loss_fn(model_xp, target_xp, reduction="none")
512
+ fk_loss = reduce(fk_loss, "b ... -> b (...)", "mean")
513
+ fk_loss = fk_loss * extract(self.p2_loss_weight, t, fk_loss.shape)
514
+
515
+ # foot skate loss
516
+ foot_idx = [7, 8, 10, 11]
517
+ # find static indices consistent with model's own predictions
518
+ static_idx = model_contact > 0.95 # N x S x 4
519
+ model_feet = model_xp[:, :, foot_idx] # foot positions (N, S, 4, 3)
520
+ model_foot_v = torch.zeros_like(model_feet)
521
+ model_foot_v[:, :-1] = (
522
+ model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :]
523
+ ) # (N, S-1, 4, 3)
524
+ model_foot_v[~static_idx] = 0
525
+ foot_loss = self.loss_fn(
526
+ model_foot_v, torch.zeros_like(model_foot_v), reduction="none"
527
+ )
528
+ foot_loss = reduce(foot_loss, "b ... -> b (...)", "mean")
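+ # hard-coded weights for the (reconstruction, velocity, FK joint position, foot-skate) loss terms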
529
+ losses = (
530
+ 0.636 * loss.mean(),
531
+ 2.964 * v_loss.mean(),
532
+ 0.646 * fk_loss.mean(),
533
+ 10.942 * foot_loss.mean(),
534
+ )
535
+
536
+ return sum(losses), losses
537
+
538
+ def loss(self, x, cond, t_override=None):
539
+ batch_size = len(x)
540
+ if t_override is None:
541
+ t = torch.randint(0, self.n_timestep, (batch_size,), device=x.device).long()
542
+ else:
543
+ t = torch.full((batch_size,), t_override, device=x.device).long()
544
+ return self.p_losses(x, cond, t)
545
+
546
+ def forward(self, x, cond, t_override=None):
547
+ return self.loss(x, cond, t_override)
548
+
549
+ def partial_denoise(self, x, cond, t):
550
+ x_noisy = self.noise_to_t(x, t)
551
+ return self.p_sample_loop(x.shape, cond, noise=x_noisy, start_point=t)
552
+
553
+ def noise_to_t(self, x, timestep):
554
+ batch_size = len(x)
555
+ t = torch.full((batch_size,), timestep, device=x.device).long()
556
+ return self.q_sample(x, t) if timestep > 0 else x
557
+
558
+ def smplxmodel_fk(self, local_q, root_pos): # input
559
+ b, s, nums, c = local_q.shape
560
+ local_q = local_q.view(b*s, -1)
561
+ full_pose = self.smplx_model(
562
+ betas = torch.zeros([b*s, 10], device=local_q.device, dtype=torch.float32),
563
+ transl = root_pos.view(b*s, -1), # global translation
564
+ global_orient = local_q[:, :3],
565
+ body_pose = local_q[:, 3:66], # 21
566
+ jaw_pose = torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1
567
+ leye_pose = torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1
568
+ reye_pose= torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1
569
+ left_hand_pose = local_q[:, 66:111], # 15
570
+ right_hand_pose = local_q[:, 111:], # 15
571
+ expression = torch.zeros([b*s, 10], device=local_q.device, dtype=torch.float32),
572
+ return_verts = False
573
+ )
574
+ full_pose = full_pose.joints.view(b, s, -1, 3) # b, s, 55, 3
575
+ return full_pose
576
+
577
+
578
+ def render_sample(
579
+ self,
580
+ shape,
581
+ cond,
582
+ normalizer,
583
+ epoch,
584
+ render_out,
585
+ fk_out=None,
586
+ name=None,
587
+ sound=True,
588
+ mode="normal",
589
+ noise=None,
590
+ constraint=None,
591
+ sound_folder="ood_sliced",
592
+ start_point=None,
593
+ render=True,
594
+ # do_normalize=True,
595
+ ):
596
+ if isinstance(shape, tuple):
597
+ if mode == "inpaint":
598
+ func_class = self.inpaint_loop
599
+ elif mode == "normal":
600
+ func_class = self.ddim_sample
601
+ elif mode == "long":
602
+ func_class = self.long_ddim_sample
603
+ else:
604
+ assert False, "Unrecognized inference mode"
605
+ samples = (
606
+ func_class(
607
+ shape,
608
+ cond,
609
+ noise=noise,
610
+ constraint=constraint,
611
+ start_point=start_point,
612
+ )
613
+ .detach()
614
+ .cpu()
615
+ )
616
+ else:
617
+ samples = shape
618
+
619
+ if self.do_normalize:
620
+ with torch.no_grad():
621
+ samples = normalizer.unnormalize(samples)
622
+
623
+ if samples.shape[2] == 319 or samples.shape[2] == 151 or samples.shape[2] == 139:  # feature dims that include 4 contact channels
624
+ sample_contact, samples = torch.split(
625
+ samples, (4, samples.shape[2] - 4), dim=2
626
+ )
627
+ else:
628
+ sample_contact = None
629
+ # do the FK all at once
630
+ b, s, c = samples.shape
631
+ pos = samples[:, :, :3].to(cond.device) # np.zeros((sample.shape[0], 3))
632
+ q = samples[:, :, 3:].reshape(b, s, -1, 6)  # 6D rotation per joint
633
+ # go 6d to ax
634
+ q = ax_from_6v(q).to(cond.device)
635
+
636
+ if self.opt.nfeats == 139 or self.opt.nfeats==135:
637
+ reshape_size = 66
638
+ else:
639
+ reshape_size = 156
640
+
641
+ if mode == "long":
642
+ b, s, c1, c2 = q.shape
643
+ assert s % 2 == 0
644
+ half = s // 2
645
+ if b > 1:
646
+ # if long mode, stitch position using linear interp
647
+ fade_out = torch.ones((1, s, 1)).to(pos.device)
648
+ fade_in = torch.ones((1, s, 1)).to(pos.device)
649
+ fade_out[:, half:, :] = torch.linspace(1, 0, half)[None, :, None].to(
650
+ pos.device
651
+ )
652
+ fade_in[:, :half, :] = torch.linspace(0, 1, half)[None, :, None].to(
653
+ pos.device
654
+ )
655
+
656
+ pos[:-1] *= fade_out
657
+ pos[1:] *= fade_in
658
+
659
+ full_pos = torch.zeros((s + half * (b - 1), 3)).to(pos.device)
660
+ idx = 0
661
+ for pos_slice in pos:
662
+ full_pos[idx : idx + s] += pos_slice
663
+ idx += half
664
+
665
+ # stitch joint angles with slerp
666
+ slerp_weight = torch.linspace(0, 1, half)[None, :, None].to(pos.device)
667
+
668
+ left, right = q[:-1, half:], q[1:, :half]
669
+ # convert to quat
670
+ left, right = (
671
+ axis_angle_to_quaternion(left),
672
+ axis_angle_to_quaternion(right),
673
+ )
674
+ merged = quat_slerp(left, right, slerp_weight) # (b-1) x half x ...
675
+ # convert back
676
+ merged = quaternion_to_axis_angle(merged)
677
+
678
+ full_q = torch.zeros((s + half * (b - 1), c1, c2)).to(pos.device)
679
+ full_q[:half] += q[0, :half]
680
+ idx = half
681
+ for q_slice in merged:
682
+ full_q[idx : idx + half] += q_slice
683
+ idx += half
684
+ full_q[idx : idx + half] += q[-1, half:]
685
+
686
+ # unsqueeze for fk
687
+ full_pos = full_pos.unsqueeze(0)
688
+ full_q = full_q.unsqueeze(0)
689
+ else:
690
+ full_pos = pos
691
+ full_q = q
692
+
693
+
694
+ if fk_out is not None:
695
+ outname = f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.pkl'
696
+ Path(fk_out).mkdir(parents=True, exist_ok=True)
697
+ pickle.dump(
698
+ {
699
+ "smpl_poses": full_q.squeeze(0).reshape((-1, reshape_size)).cpu().numpy(), # local rotations
700
+ "smpl_trans": full_pos.squeeze(0).cpu().numpy(), # root translation
701
+ # "full_pose": full_pose[0], # 3d positions
702
+ },
703
+ open(os.path.join(fk_out, outname), "wb"),
704
+ )
705
+ return
706
+
707
+
708
+ sample_contact = (
709
+ sample_contact.detach().cpu().numpy()
710
+ if sample_contact is not None
711
+ else None
712
+ )
713
+ def inner(xx):
714
+ num, pose = xx
715
+ filename = name[num] if name is not None else None
716
+ contact = sample_contact[num] if sample_contact is not None else None
717
+ skeleton_render(
718
+ pose,
719
+ epoch=f"e{epoch}_b{num}",
720
+ out=render_out,
721
+ name=filename,
722
+ sound=sound,
723
+ contact=contact,
724
+ )
725
+
726
+ # p_map(inner, enumerate(poses)) # poses: 2, 150, 52, 3
728
+ if fk_out is not None and mode != "long":
729
+ Path(fk_out).mkdir(parents=True, exist_ok=True)
730
+ # for num, (qq, pos_, filename, pose) in enumerate(zip(q, pos, name, poses)):
731
+ for num, (qq, pos_, filename) in enumerate(zip(q, pos, name)):
732
+ filename = os.path.basename(filename).split(".")[0]
733
+ outname = f"{epoch}_{num}_{filename}.pkl"
734
+ pickle.dump(
735
+ {
736
+ "smpl_poses": qq.reshape((-1, reshape_size)).cpu().numpy(),
737
+ "smpl_trans": pos_.cpu().numpy(),
738
+ # "full_pose": pose,
739
+ },
740
+ open(f"{fk_out}/{outname}", "wb"),
741
+ )
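The sampling helpers above rely on predict_start_from_noise being the exact inverse of q_sample. A self-contained check with a toy linear schedule (the constants are illustrative, standing in for make_beta_schedule("linear")):

    import torch

    T = 1000
    betas = torch.linspace(1e-4, 2e-2, T)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    x0 = torch.randn(2, 8)
    t = torch.tensor([10, 500])
    eps = torch.randn_like(x0)

    ac = alphas_cumprod[t].unsqueeze(-1)
    x_t = ac.sqrt() * x0 + (1 - ac).sqrt() * eps                      # q_sample
    x0_rec = (1.0 / ac).sqrt() * x_t - (1.0 / ac - 1.0).sqrt() * eps  # predict_start_from_noise
    assert torch.allclose(x0, x0_rec, atol=1e-5)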
model/model.py ADDED
@@ -0,0 +1,444 @@
1
+ from typing import Any, Callable, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, reduce, repeat
7
+ from einops.layers.torch import Rearrange, Reduce
8
+ from torch import Tensor
9
+ from torch.nn import functional as F
10
+
11
+ from model.rotary_embedding_torch import RotaryEmbedding
12
+ from model.utils import PositionalEncoding, SinusoidalPosEmb, prob_mask_like
13
+
14
+
15
+ class DenseFiLM(nn.Module):
16
+ """Feature-wise linear modulation (FiLM) generator."""
17
+
18
+ def __init__(self, embed_channels):
19
+ super().__init__()
20
+ self.embed_channels = embed_channels
21
+ self.block = nn.Sequential(
22
+ nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
23
+ )
24
+
25
+ def forward(self, position):
26
+ pos_encoding = self.block(position)
27
+ pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
28
+ scale_shift = pos_encoding.chunk(2, dim=-1)
29
+ return scale_shift
30
+
31
+
32
+ def featurewise_affine(x, scale_shift):
33
+ scale, shift = scale_shift
34
+ return (scale + 1) * x + shift
35
+
36
+
37
+ class TransformerEncoderLayer(nn.Module):
38
+ def __init__(
39
+ self,
40
+ d_model: int,
41
+ nhead: int,
42
+ dim_feedforward: int = 2048,
43
+ dropout: float = 0.1,
44
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
45
+ layer_norm_eps: float = 1e-5,
46
+ batch_first: bool = False,
47
+ norm_first: bool = True,
48
+ device=None,
49
+ dtype=None,
50
+ rotary=None,
51
+ ) -> None:
52
+ super().__init__()
53
+ self.self_attn = nn.MultiheadAttention(
54
+ d_model, nhead, dropout=dropout, batch_first=batch_first
55
+ )
56
+ # Implementation of Feedforward model
57
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
58
+ self.dropout = nn.Dropout(dropout)
59
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
60
+
61
+ self.norm_first = norm_first
62
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
63
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
64
+ self.dropout1 = nn.Dropout(dropout)
65
+ self.dropout2 = nn.Dropout(dropout)
66
+ self.activation = activation
67
+
68
+ self.rotary = rotary
69
+ self.use_rotary = rotary is not None
70
+
71
+ def forward(
72
+ self,
73
+ src: Tensor,
74
+ src_mask: Optional[Tensor] = None,
75
+ src_key_padding_mask: Optional[Tensor] = None,
76
+ ) -> Tensor:
77
+ x = src
78
+ if self.norm_first:
79
+ x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
80
+ x = x + self._ff_block(self.norm2(x))
81
+ else:
82
+ x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
83
+ x = self.norm2(x + self._ff_block(x))
84
+
85
+ return x
86
+
87
+ # self-attention block
88
+ def _sa_block(
89
+ self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
90
+ ) -> Tensor:
91
+ qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
92
+ x = self.self_attn(
93
+ qk,
94
+ qk,
95
+ x,
96
+ attn_mask=attn_mask,
97
+ key_padding_mask=key_padding_mask,
98
+ need_weights=False,
99
+ )[0]
100
+ return self.dropout1(x)
101
+
102
+ # feed forward block
103
+ def _ff_block(self, x: Tensor) -> Tensor:
104
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
105
+ return self.dropout2(x)
106
+
107
+
108
+ class FiLMTransformerDecoderLayer(nn.Module):
109
+ def __init__(
110
+ self,
111
+ d_model: int,
112
+ nhead: int,
113
+ dim_feedforward=2048,
114
+ dropout=0.1,
115
+ activation=F.relu,
116
+ layer_norm_eps=1e-5,
117
+ batch_first=False,
118
+ norm_first=True,
119
+ device=None,
120
+ dtype=None,
121
+ rotary=None,
122
+ ):
123
+ super().__init__()
124
+ self.self_attn = nn.MultiheadAttention(
125
+ d_model, nhead, dropout=dropout, batch_first=batch_first
126
+ )
127
+ self.multihead_attn = nn.MultiheadAttention(
128
+ d_model, nhead, dropout=dropout, batch_first=batch_first
129
+ )
130
+ # Feedforward
131
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
132
+ self.dropout = nn.Dropout(dropout)
133
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
134
+
135
+ self.norm_first = norm_first
136
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
137
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
138
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
139
+ self.dropout1 = nn.Dropout(dropout)
140
+ self.dropout2 = nn.Dropout(dropout)
141
+ self.dropout3 = nn.Dropout(dropout)
142
+ self.activation = activation
143
+
144
+ self.film1 = DenseFiLM(d_model)
145
+ self.film2 = DenseFiLM(d_model)
146
+ self.film3 = DenseFiLM(d_model)
147
+
148
+ self.rotary = rotary
149
+ self.use_rotary = rotary is not None
150
+
151
+ # x, cond, t
152
+ def forward(
153
+ self,
154
+ tgt,
155
+ memory,
156
+ t,
157
+ tgt_mask=None,
158
+ memory_mask=None,
159
+ tgt_key_padding_mask=None,
160
+ memory_key_padding_mask=None,
161
+ ):
162
+ x = tgt
163
+ if self.norm_first:
164
+ # self-attention -> film -> residual
165
+ x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
166
+ x = x + featurewise_affine(x_1, self.film1(t))
167
+ # cross-attention -> film -> residual
168
+ x_2 = self._mha_block(
169
+ self.norm2(x), memory, memory_mask, memory_key_padding_mask
170
+ )
171
+ x = x + featurewise_affine(x_2, self.film2(t))
172
+ # feedforward -> film -> residual
173
+ x_3 = self._ff_block(self.norm3(x))
174
+ x = x + featurewise_affine(x_3, self.film3(t))
175
+ else:
176
+ x = self.norm1(
177
+ x
178
+ + featurewise_affine(
179
+ self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
180
+ )
181
+ )
182
+ x = self.norm2(
183
+ x
184
+ + featurewise_affine(
185
+ self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
186
+ self.film2(t),
187
+ )
188
+ )
189
+ x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
190
+ return x
191
+
192
+ # self-attention block
193
+ # qkv
194
+ def _sa_block(self, x, attn_mask, key_padding_mask):
195
+ qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
196
+ x = self.self_attn(
197
+ qk,
198
+ qk,
199
+ x,
200
+ attn_mask=attn_mask,
201
+ key_padding_mask=key_padding_mask,
202
+ need_weights=False,
203
+ )[0]
204
+ return self.dropout1(x)
205
+
206
+ # multihead attention block
207
+ # qkv
208
+ def _mha_block(self, x, mem, attn_mask, key_padding_mask):
209
+ q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
210
+ k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
211
+ x = self.multihead_attn(
212
+ q,
213
+ k,
214
+ mem,
215
+ attn_mask=attn_mask,
216
+ key_padding_mask=key_padding_mask,
217
+ need_weights=False,
218
+ )[0]
219
+ return self.dropout2(x)
220
+
221
+ # feed forward block
222
+ def _ff_block(self, x):
223
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
224
+ return self.dropout3(x)
225
+
226
+
227
+ class DecoderLayerStack(nn.Module):
228
+ def __init__(self, stack):
229
+ super().__init__()
230
+ self.stack = stack
231
+
232
+ def forward(self, x, cond, t):
233
+ for layer in self.stack:
234
+ x = layer(x, cond, t)
235
+ return x
236
+
237
+
238
+
239
+ class SeqModel(nn.Module):
240
+ def __init__(self,
241
+ nfeats: int,
242
+ seq_len: int = 150, # 5 seconds, 30 fps
243
+ latent_dim: int = 256,
244
+ ff_size: int = 1024,
245
+ num_layers: int = 4,
246
+ num_heads: int = 4,
247
+ dropout: float = 0.1,
248
+ cond_feature_dim: int = 35,
249
+ activation: Callable[[Tensor], Tensor] = F.gelu,
250
+ use_rotary=True,
251
+ **kwargs
252
+ ) -> None:
253
+ super().__init__()
254
+
255
+ self.network = nn.ModuleDict()
256
+ self.network['body_net'] = DanceDecoder(
257
+ nfeats=4+3+22*6,
258
+ seq_len=seq_len,
259
+ latent_dim=latent_dim,
260
+ ff_size=ff_size,
261
+ num_layers=num_layers,
262
+ num_heads=num_heads,
263
+ dropout=dropout,
264
+ cond_feature_dim=cond_feature_dim,
265
+ activation=activation
266
+ )
267
+ self.network['hand_net'] = DanceDecoder(
268
+ nfeats=30*6,
269
+ seq_len=seq_len,
270
+ latent_dim=latent_dim,
271
+ ff_size=ff_size,
272
+ num_layers=num_layers,
273
+ num_heads=num_heads,
274
+ dropout=dropout,
275
+ cond_feature_dim=35+139,  # music features (35) + body-branch output (139)
276
+ activation=activation
277
+ )
278
+
279
+
280
+ def forward(self, x: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0):
281
+ x_body_start = x[:,:,:4+135]
282
+ x_hand_start = x[:,:,4+135:]
283
+ body_output = self.network['body_net'](x_body_start, cond_embed, times, cond_drop_prob)
284
+
285
+ cond_embed = torch.cat([body_output, cond_embed], dim = -1)
286
+ hand_output = self.network['hand_net'](x_hand_start, cond_embed, times, cond_drop_prob)
287
+
288
+ output = torch.cat([body_output, hand_output], dim=-1)
289
+ return output
290
+
291
+ def guided_forward(self, x, cond_embed, times, guidance_weight):
292
+ unc = self.forward(x, cond_embed, times, cond_drop_prob=1)
293
+
294
+ conditioned = self.forward(x, cond_embed, times, cond_drop_prob=0)
295
+ return unc + (conditioned - unc) * guidance_weight
296
+
297
+
298
+ class DanceDecoder(nn.Module):
299
+ def __init__(
300
+ self,
301
+ nfeats: int,
302
+ seq_len: int = 150, # 5 seconds, 30 fps
303
+ latent_dim: int = 256,
304
+ ff_size: int = 1024,
305
+ num_layers: int = 4,
306
+ num_heads: int = 4,
307
+ dropout: float = 0.1,
308
+ cond_feature_dim: int = 35,
309
+ activation: Callable[[Tensor], Tensor] = F.gelu,
310
+ use_rotary=True,
311
+ **kwargs
312
+ ) -> None:
313
+
314
+ super().__init__()
315
+
316
+ output_feats = nfeats
317
+
318
+ # positional embeddings
319
+ self.rotary = None
320
+ self.abs_pos_encoding = nn.Identity()
321
+ # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
322
+ if use_rotary:
323
+ self.rotary = RotaryEmbedding(dim=latent_dim)
324
+ else:
325
+ self.abs_pos_encoding = PositionalEncoding(
326
+ latent_dim, dropout, batch_first=True
327
+ )
328
+
329
+ # time embedding processing
330
+ self.time_mlp = nn.Sequential(
331
+ SinusoidalPosEmb(latent_dim),
332
+ nn.Linear(latent_dim, latent_dim * 4),
333
+ nn.Mish(),
334
+ )
335
+
336
+ self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim),)
337
+
338
+ self.to_time_tokens = nn.Sequential(
339
+ nn.Linear(latent_dim * 4, latent_dim * 2), # 2 time tokens
340
+ Rearrange("b (r d) -> b r d", r=2),
341
+ )
342
+
343
+ # null embeddings for guidance dropout
344
+ self.null_cond_embed = nn.Parameter(torch.randn(1, seq_len, latent_dim))
345
+ self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
346
+
347
+ self.norm_cond = nn.LayerNorm(latent_dim)
348
+
349
+ # input projection
350
+ self.input_projection = nn.Linear(nfeats, latent_dim)
351
+ self.cond_encoder = nn.Sequential()
352
+ for _ in range(2):
353
+ self.cond_encoder.append(
354
+ TransformerEncoderLayer(
355
+ d_model=latent_dim,
356
+ nhead=num_heads,
357
+ dim_feedforward=ff_size,
358
+ dropout=dropout,
359
+ activation=activation,
360
+ batch_first=True,
361
+ rotary=self.rotary,
362
+ )
363
+ )
364
+ # conditional projection
365
+ self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
366
+ self.non_attn_cond_projection = nn.Sequential(
367
+ nn.LayerNorm(latent_dim),
368
+ nn.Linear(latent_dim, latent_dim),
369
+ nn.SiLU(),
370
+ nn.Linear(latent_dim, latent_dim),
371
+ )
372
+ # decoder
373
+ decoderstack = nn.ModuleList([])
374
+ for _ in range(num_layers):
375
+ decoderstack.append(
376
+ FiLMTransformerDecoderLayer(
377
+ latent_dim,
378
+ num_heads,
379
+ dim_feedforward=ff_size,
380
+ dropout=dropout,
381
+ activation=activation,
382
+ batch_first=True,
383
+ rotary=self.rotary,
384
+ )
385
+ )
386
+
387
+ self.seqTransDecoder = DecoderLayerStack(decoderstack)
388
+
389
+ self.final_layer = nn.Linear(latent_dim, output_feats)
390
+
391
+ def guided_forward(self, x, cond_embed, times, guidance_weight):
392
+ unc = self.forward(x, cond_embed, times, cond_drop_prob=1)
393
+ conditioned = self.forward(x, cond_embed, times, cond_drop_prob=0)
394
+
395
+ return unc + (conditioned - unc) * guidance_weight
396
+
397
+ def forward(
398
+ self, x: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
399
+ ):
400
+ batch_size, device = x.shape[0], x.device
401
+
402
+ # project to latent space
403
+ x = self.input_projection(x)
404
+ # add the positional embeddings of the input sequence to provide temporal information
405
+ x = self.abs_pos_encoding(x)
406
+
407
+ # create music conditional embedding with conditional dropout
408
+ keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
409
+ keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
410
+ keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
411
+
412
+ cond_tokens = self.cond_projection(cond_embed)
413
+ # encode tokens
414
+ cond_tokens = self.abs_pos_encoding(cond_tokens)
415
+ cond_tokens = self.cond_encoder(cond_tokens)
416
+
417
+ null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
418
+ cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)
419
+
420
+ mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
421
+ cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
422
+
423
+ # create the diffusion timestep embedding, add the extra music projection
424
+ t_hidden = self.time_mlp(times)
425
+
426
+ # project to attention and FiLM conditioning
427
+ t = self.to_time_cond(t_hidden)
428
+ t_tokens = self.to_time_tokens(t_hidden)
429
+
430
+ # FiLM conditioning
431
+ null_cond_hidden = self.null_cond_hidden.to(t.dtype)
432
+ cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
433
+ t += cond_hidden
434
+
435
+ # cross-attention conditioning
436
+ c = torch.cat((cond_tokens, t_tokens), dim=-2)
437
+ cond_tokens = self.norm_cond(c)
438
+
439
+ # Pass through the transformer decoder
440
+ # attending to the conditional embedding
441
+ output = self.seqTransDecoder(x, cond_tokens, t)
442
+
443
+ output = self.final_layer(output)
444
+ return output
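The guided_forward method above implements classifier-free guidance: one pass with the music condition fully dropped, one with it fully kept, then a linear extrapolation by guidance_weight. A minimal, self-contained sketch of just that combination step (the tensors are dummies and the shapes are illustrative):

    import torch

    def cfg_combine(unconditional, conditioned, guidance_weight):
        # extrapolate from the unconditional prediction toward (and, for
        # weight > 1, past) the conditioned prediction
        return unconditional + (conditioned - unconditional) * guidance_weight

    unc = torch.zeros(2, 150, 319)   # stands in for forward(..., cond_drop_prob=1)
    cond = torch.ones(2, 150, 319)   # stands in for forward(..., cond_drop_prob=0)
    out = cfg_combine(unc, cond, guidance_weight=2.0)
    print(out[0, 0, 0].item())       # 2.0, pushed past the conditioned value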
model/rotary_embedding_torch.py ADDED
@@ -0,0 +1,132 @@
1
+ from inspect import isfunction
2
+ from math import log, pi
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import einsum, nn
7
+
8
+ # helper functions
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def broadcat(tensors, dim=-1):
16
+ num_tensors = len(tensors)
17
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
18
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
19
+ shape_len = list(shape_lens)[0]
20
+
21
+ dim = (dim + shape_len) if dim < 0 else dim
22
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
23
+
24
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
25
+ assert all(
26
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
27
+ ), "invalid dimensions for broadcastable concatentation"
28
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
29
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
30
+ expanded_dims.insert(dim, (dim, dims[dim]))
31
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
32
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
33
+ return torch.cat(tensors, dim=dim)
34
+
35
+
36
+ # rotary embedding helper functions
37
+
38
+
39
+ def rotate_half(x):
40
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
41
+ x1, x2 = x.unbind(dim=-1)
42
+ x = torch.stack((-x2, x1), dim=-1)
43
+ return rearrange(x, "... d r -> ... (d r)")
44
+
45
+
46
+ def apply_rotary_emb(freqs, t, start_index=0):
47
+ freqs = freqs.to(t)
48
+ rot_dim = freqs.shape[-1]
49
+ end_index = start_index + rot_dim
50
+ assert (
51
+ rot_dim <= t.shape[-1]
52
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
53
+ t_left, t, t_right = (
54
+ t[..., :start_index],
55
+ t[..., start_index:end_index],
56
+ t[..., end_index:],
57
+ )
58
+ t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
59
+ return torch.cat((t_left, t, t_right), dim=-1)
60
+
61
+
62
+ # learned rotation helpers
63
+
64
+
65
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
66
+ if exists(freq_ranges):
67
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
68
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
69
+
70
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
71
+ return apply_rotary_emb(rotations, t, start_index=start_index)
72
+
73
+
74
+ # classes
75
+
76
+
77
+ class RotaryEmbedding(nn.Module):
78
+ def __init__(
79
+ self,
80
+ dim,
81
+ custom_freqs=None,
82
+ freqs_for="lang",
83
+ theta=10000,
84
+ max_freq=10,
85
+ num_freqs=1,
86
+ learned_freq=False,
87
+ ):
88
+ super().__init__()
89
+ if exists(custom_freqs):
90
+ freqs = custom_freqs
91
+ elif freqs_for == "lang":
92
+ freqs = 1.0 / (
93
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
94
+ )
95
+ elif freqs_for == "pixel":
96
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
97
+ elif freqs_for == "constant":
98
+ freqs = torch.ones(num_freqs).float()
99
+ else:
100
+ raise ValueError(f"unknown modality {freqs_for}")
101
+
102
+ self.cache = dict()
103
+
104
+ if learned_freq:
105
+ self.freqs = nn.Parameter(freqs)
106
+ else:
107
+ self.register_buffer("freqs", freqs)
108
+
109
+ def rotate_queries_or_keys(self, t, seq_dim=-2):
110
+ device = t.device
111
+ seq_len = t.shape[seq_dim]
112
+ freqs = self.forward(
113
+ lambda: torch.arange(seq_len, device=device), cache_key=seq_len
114
+ )
115
+ return apply_rotary_emb(freqs, t)
116
+
117
+ def forward(self, t, cache_key=None):
118
+ if exists(cache_key) and cache_key in self.cache:
119
+ return self.cache[cache_key]
120
+
121
+ if isfunction(t):
122
+ t = t()
123
+
124
+ freqs = self.freqs
125
+
126
+ freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
127
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
128
+
129
+ if exists(cache_key):
130
+ self.cache[cache_key] = freqs
131
+
132
+ return freqs
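A minimal usage sketch for the RotaryEmbedding class above, assuming the repository root is on sys.path and query/key tensors shaped (batch, heads, seq, head_dim); the data is random:

    import torch
    from model.rotary_embedding_torch import RotaryEmbedding

    rotary = RotaryEmbedding(dim=32)        # rotate the first 32 of 64 head dims
    q = torch.randn(1, 8, 150, 64)          # (batch, heads, seq, head_dim)
    k = torch.randn(1, 8, 150, 64)
    q = rotary.rotate_queries_or_keys(q)    # positions enter as rotations, so the
    k = rotary.rotate_queries_or_keys(k)    # q-k dot product depends only on relative offset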
model/utils.py ADDED
@@ -0,0 +1,99 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ from einops import rearrange, reduce, repeat
6
+ from einops.layers.torch import Rearrange
7
+ from torch import nn
8
+
9
+
10
+ # absolute positional embedding used for vanilla transformer sequential data
11
+ class PositionalEncoding(nn.Module):
12
+ def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
13
+ super().__init__()
14
+ self.batch_first = batch_first
15
+
16
+ self.dropout = nn.Dropout(p=dropout)
17
+
18
+ pe = torch.zeros(max_len, d_model)
19
+ position = torch.arange(0, max_len).unsqueeze(1)
20
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
21
+ pe[:, 0::2] = torch.sin(position * div_term)
22
+ pe[:, 1::2] = torch.cos(position * div_term)
23
+ pe = pe.unsqueeze(0).transpose(0, 1)
24
+
25
+ self.register_buffer("pe", pe)
26
+
27
+ def forward(self, x):
28
+ if self.batch_first:
29
+ x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
30
+ else:
31
+ x = x + self.pe[: x.shape[0], :]
32
+ return self.dropout(x)
33
+
34
+
35
+ # very similar positional embedding used for diffusion timesteps
36
+ class SinusoidalPosEmb(nn.Module):
37
+ def __init__(self, dim):
38
+ super().__init__()
39
+ self.dim = dim
40
+
41
+ def forward(self, x):
42
+ device = x.device
43
+ half_dim = self.dim // 2
44
+ emb = math.log(10000) / (half_dim - 1)
45
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
46
+ emb = x[:, None] * emb[None, :]
47
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
48
+ return emb
49
+
50
+
51
+ # dropout mask
52
+ def prob_mask_like(shape, prob, device):
53
+ if prob == 1:
54
+ return torch.ones(shape, device=device, dtype=torch.bool)
55
+ elif prob == 0:
56
+ return torch.zeros(shape, device=device, dtype=torch.bool)
57
+ else:
58
+ return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
59
+
60
+
61
+ def extract(a, t, x_shape):
62
+ b, *_ = t.shape
63
+ out = a.gather(-1, t)
64
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
65
+
66
+
67
+ def make_beta_schedule(
68
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
69
+ ):
70
+ if schedule == "linear":
71
+ betas = (
72
+ torch.linspace(
73
+ linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
74
+ )
75
+ ** 2
76
+ )
77
+
78
+ elif schedule == "cosine":
79
+ timesteps = (
80
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
81
+ )
82
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
83
+ alphas = torch.cos(alphas).pow(2)
84
+ alphas = alphas / alphas[0]
85
+ betas = 1 - alphas[1:] / alphas[:-1]
86
+ betas = np.clip(betas, a_min=0, a_max=0.999)
87
+
88
+ elif schedule == "sqrt_linear":
89
+ betas = torch.linspace(
90
+ linear_start, linear_end, n_timestep, dtype=torch.float64
91
+ )
92
+ elif schedule == "sqrt":
93
+ betas = (
94
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
95
+ ** 0.5
96
+ )
97
+ else:
98
+ raise ValueError(f"schedule '{schedule}' unknown.")
99
+ return betas.numpy()
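A quick sketch of how these schedule utilities compose in a DDPM-style setup. The training script passes schedule="cosine" and n_timestep=1000 to GaussianDiffusion; "linear" is used here only to keep the example minimal:

    import numpy as np
    import torch
    from model.utils import make_beta_schedule, extract

    betas = make_beta_schedule("linear", n_timestep=1000)       # np.ndarray, shape (1000,)
    alphas_cumprod = torch.from_numpy(np.cumprod(1.0 - betas))  # cumulative alpha-bar

    t = torch.tensor([0, 999])                  # per-sample diffusion timesteps
    x_shape = (2, 150, 319)                     # hypothetical motion batch shape
    coef = extract(alphas_cumprod, t, x_shape)  # shape (2, 1, 1), broadcastable over x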
render.py ADDED
@@ -0,0 +1,395 @@
1
+ import pickle
2
+ import numpy as np
3
+ import torch
4
+ import cv2
5
+ import os
6
+ # os.environ["PYOPENGL_PLATFORM"] = "osmesa" # Not available on macOS
7
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
8
+ from tqdm import tqdm
9
+ from smplx import SMPL, SMPLX, SMPLH
10
+ import pyrender
11
+ import trimesh
12
+ import subprocess
13
+ import pickle
14
+ from pytorch3d.transforms import (axis_angle_to_matrix, matrix_to_axis_angle,
15
+ matrix_to_quaternion, matrix_to_rotation_6d,
16
+ quaternion_to_matrix, rotation_6d_to_matrix)
17
+
18
+ import sys
19
+ sys.path.append('.')
20
+ import argparse
21
+
22
+
23
+ def quat_to_6v(q):
24
+ assert q.shape[-1] == 4
25
+ mat = quaternion_to_matrix(q)
26
+ mat = matrix_to_rotation_6d(mat)
27
+ return mat
28
+
29
+
30
+ def quat_from_6v(q):
31
+ assert q.shape[-1] == 6
32
+ mat = rotation_6d_to_matrix(q)
33
+ quat = matrix_to_quaternion(mat)
34
+ return quat
35
+
36
+
37
+ def ax_to_6v(q):
38
+ assert q.shape[-1] == 3
39
+ mat = axis_angle_to_matrix(q)
40
+ mat = matrix_to_rotation_6d(mat)
41
+ return mat
42
+
43
+
44
+ def ax_from_6v(q):
45
+ assert q.shape[-1] == 6
46
+ mat = rotation_6d_to_matrix(q)
47
+ ax = matrix_to_axis_angle(mat)
48
+ return ax
49
+
50
+ class MovieMaker():
51
+ def __init__(self, save_path) -> None:
52
+ self.mag = 2
53
+ self.eyes = np.array([[3,-3,2], [0,0,-2], [0,0,4], [-8,-8,1], [0,-2,4], [0,2,4]])
54
+ self.centers = np.array([[0,0,0],[0,0,0],[0,0.5,0],[0,0,-1], [0,0.5,0], [0,0.5,0]])
55
+ self.ups = np.array([[0,0,1],[0,1,0],[0,1,0],[0,0,-1], [0,1,0], [0,1,0]])
56
+ self.save_path = save_path
57
+ self.fps = args.fps
58
+ self.img_size = (1200,1200)
59
+
60
+
61
+ # SMPLH_path = "assets/smpl_model/smplh/SMPLH_MALE.pkl"
62
+ # SMPL_path = "assets/smpl_model/smpl/SMPL_MALE.pkl"
63
+ SMPLX_path = "assets/smpl_model/smplx/SMPLX_NEUTRAL.npz"
64
+ trimesh_path = 'assets/NORMAL_new.obj'
65
+
66
+
67
+ # self.smplh = SMPLH(SMPLH_path, use_pca=False, flat_hand_mean=True)
68
+ # self.smplh.to(f'cuda:{args.gpu}').eval()
69
+
70
+ # self.smpl = SMPL(SMPL_path)
71
+ # self.smpl.to(f'cuda:{args.gpu}').eval()
72
+
73
+ self.smplx = SMPLX(SMPLX_path, use_pca=False, flat_hand_mean=True).eval()
74
+ _device = "mps" if torch.backends.mps.is_available() else "cpu"
75
+ self.smplx.to(_device).eval()
76
+
77
+ self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 1.0])
78
+ camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
79
+ camera_pose = look_at(self.eyes[5], self.centers[5], self.ups[5]) # 2
80
+ self.scene.add(camera, pose=camera_pose)
81
+ light = pyrender.DirectionalLight(color=np.ones(3), intensity=3.0)
82
+ self.scene.add(light, pose=camera_pose)
83
+ self.r = pyrender.OffscreenRenderer(self.img_size[0], self.img_size[1])
84
+
85
+
86
+ # self.mesh = trimesh.load(trimesh_path)
87
+ # floor_mesh = pyrender.Mesh.from_trimesh(self.mesh)
88
+ # floor_node = self.scene.add(floor_mesh)
89
+
90
+
91
+ def save_video(self, save_path, color_list):
92
+ # save_path = os.path.join(save_path,'move.mp4')
93
+ f = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
94
+ videowriter = cv2.VideoWriter(save_path,f,self.fps,self.img_size)
95
+ for i in range(len(color_list)):
96
+ videowriter.write(color_list[i][:,:,::-1])
97
+ videowriter.release()
98
+
99
+ def get_imgs(self, motion):
100
+ meshes = self.motion2mesh(motion)
101
+ imgs = self.render_imgs(meshes)
102
+ return np.concatenate(imgs, axis=1)
103
+
104
+ def motion2mesh(self, motion):
105
+ if args.mode == "smpl":
106
+ output = self.smpl.forward(
107
+ betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
108
+ transl = motion[:,:3],
109
+ global_orient = motion[:,3:6],
110
+ body_pose = torch.cat([motion[:,6:69], motion[:,69:72], motion[:,114:117]], dim=1)
111
+ )
112
+ elif args.mode == "smplh":
113
+ output = self.smplh.forward(
114
+ betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
115
+ # transl = motion[:,:3],
116
+ transl = torch.tensor([[0,0,-1]]).expand(motion.shape[0],-1).to(motion.device) ,
117
+ global_orient = motion[:,3:6],
118
+ body_pose = motion[:,6:69],
119
+ left_hand_pose = motion[:,69:114],
120
+ right_hand_pose = motion[:,114:159],
121
+ )
122
+ elif args.mode == "smplx":
123
+ output = self.smplx.forward(
124
+ betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
125
+ # transl = motion[:,:3],
126
+ transl = motion[:,:3],
127
+ global_orient = motion[:,3:6],
128
+ body_pose = motion[:,6:69],
129
+ jaw_pose = torch.zeros([motion.shape[0], 3]).to(motion),
130
+ leye_pose = torch.zeros([motion.shape[0], 3]).to(motion),
131
+ reye_pose = torch.zeros([motion.shape[0], 3]).to(motion),
132
+ left_hand_pose = motion[:,69:69+45],
133
+ right_hand_pose = motion[:,69+45:],
134
+ expression= torch.zeros([motion.shape[0], 10]).to(motion),
135
+ )
136
+
137
+ meshes = []
138
+ for i in range(output.vertices.shape[0]):
139
+ if args.mode == 'smplh':
140
+ mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplh.faces)
141
+ elif args.mode == 'smplx':
142
+ mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplx.faces)
143
+ elif args.mode == 'smpl':
144
+ mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smpl.faces)
145
+ # mesh.export(os.path.join(self.save_path, f'{i}.obj'))
146
+ meshes.append(mesh)
147
+
148
+ return meshes
149
+
150
+
151
+ def render_multi_view(self, meshes, music_file, tab='', eyes=None, centers=None, ups=None, views=1):
152
+ if eyes is not None and centers is not None and ups is not None: # numpy arrays are ambiguous in a bare boolean test
153
+ assert eyes.shape == centers.shape == ups.shape
154
+ else:
155
+ eyes = self.eyes
156
+ centers = self.centers
157
+ ups = self.ups
158
+
159
+ for i in range(views):
160
+ color_list = self.render_single_view(meshes) # render_single_view takes no camera args; the pose is fixed at scene construction
161
+ movie_file = os.path.join(self.save_path, tab + '-' + str(i) + '.mp4')
162
+ output_file = os.path.join(self.save_path, tab + '-' + str(i) + '-music.mp4')
163
+ self.save_video(movie_file, color_list)
164
+ if music_file is not None:
165
+ subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
166
+ else:
167
+ subprocess.run(['ffmpeg','-i',movie_file,output_file])
168
+ # if music_file is not None:
169
+ # subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
170
+ # else:
171
+ # subprocess.run(['ffmpeg','-i',movie_file,output_file])
172
+ os.remove(movie_file)
173
+
174
+
175
+
176
+
177
+ def render_single_view(self, meshes):
178
+ num = len(meshes)
179
+ color_list = []
180
+ for i in tqdm(range(num)):
181
+ mesh_nodes = []
182
+ for mesh in meshes[i]:
183
+ render_mesh = pyrender.Mesh.from_trimesh(mesh)
184
+ mesh_node = self.scene.add(render_mesh)
185
+ mesh_nodes.append(mesh_node)
186
+ color, _ = self.r.render(self.scene, flags=pyrender.RenderFlags.SHADOWS_DIRECTIONAL)
187
+ color = color.copy()
188
+ color_list.append(color)
189
+ for mesh_node in mesh_nodes:
190
+ self.scene.remove_node(mesh_node)
191
+ return color_list
192
+
193
+ def render_imgs(self, meshes):
194
+ colors = []
195
+ for mesh in meshes:
196
+ render_mesh = pyrender.Mesh.from_trimesh(mesh)
197
+ mesh_node = self.scene.add(render_mesh)
198
+ color, _ = self.r.render(self.scene, flags=pyrender.RenderFlags.SHADOWS_DIRECTIONAL)
199
+ colors.append(color)
200
+ self.scene.remove_node(mesh_node)
201
+
202
+
203
+ return colors
204
+ # cv2.imwrite(os.path.join(self.save_path, 'test.jpg'), color[:,:,::-1])
205
+
206
+ def run(self, seq_rot, music_file=None, tab='', save_pt=False):
207
+ if isinstance(seq_rot, np.ndarray):
208
+ seq_rot = torch.tensor(seq_rot, dtype=torch.float32, device="mps" if torch.backends.mps.is_available() else "cpu")
209
+
210
+ if save_pt:
211
+ torch.save(seq_rot.detach().cpu(), os.path.join(self.save_path, tab +'_pose.pt'))
212
+
213
+ B, D = seq_rot.shape
214
+ if args.mode == "smpl":
215
+ print("using smpl!!!")
216
+ output = self.smpl.forward(
217
+ betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
218
+ transl = seq_rot[:,:3],
219
+ global_orient = seq_rot[:,3:6],
220
+ body_pose = torch.cat([seq_rot[:,6:69], seq_rot[:,69:72], seq_rot[:,114:117]], dim=1)
221
+ )
222
+
223
+ elif args.mode == "smplh":
224
+ print("using smplh!!!")
225
+ output = self.smplh.forward(
226
+ betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
227
+ transl = seq_rot[:,:3],
228
+ global_orient = seq_rot[:,3:6],
229
+ body_pose = seq_rot[:,6:69],
230
+ left_hand_pose = seq_rot[:,69:114], # torch.zeros([seq_rot.shape[0], 45]).to(seq_rot.device), # seq_rot[:,69:114],
231
+ right_hand_pose = seq_rot[:,114:], # torch.zeros([seq_rot.shape[0], 45]).to(seq_rot.device), #
232
+ expression = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
233
+ )
234
+
235
+ elif args.mode == "smplx":
236
+ output = self.smplx.forward(
237
+ betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
238
+ # transl = motion[:,:3],
239
+ transl = seq_rot[:,:3],
240
+ global_orient = seq_rot[:,3:6],
241
+ body_pose = seq_rot[:,6:69],
242
+ jaw_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
243
+ leye_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
244
+ reye_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
245
+ left_hand_pose = seq_rot[:,69:69+45],
246
+ right_hand_pose = seq_rot[:,69+45:],
247
+ expression= torch.zeros([seq_rot.shape[0], 10]).to(seq_rot),
248
+ )
249
+
250
+ N, V, DD = output.vertices.shape # 150, 6890, 3
251
+ vertices = output.vertices.reshape((B, -1, V, DD)) # # 150, 1, 6890, 3
252
+
253
+ meshes = []
254
+ for i in range(B):
255
+ # if int(i) > 20:
256
+ # break
257
+ view = []
258
+ for v in vertices[i]:
259
+ # vertices[:,2] *= -1
260
+ if args.mode == 'smplh':
261
+ mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplh.faces)
262
+ elif args.mode == 'smplx':
263
+ mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplx.faces)
264
+ elif args.mode == 'smpl':
265
+ mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smpl.faces)
266
+ # mesh.export(os.path.join(self.save_path, 'test.obj'))
267
+ view.append(mesh)
268
+ meshes.append(view)
269
+
270
+ color_list = self.render_single_view(meshes)
271
+ movie_file = os.path.join(self.save_path, tab + 'tmp.mp4')
272
+ output_file = os.path.join(self.save_path, tab + 'z.mp4')
273
+ self.save_video(movie_file, color_list)
274
+ if music_file is not None:
275
+ subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
276
+ else:
277
+ subprocess.run(['ffmpeg','-i',movie_file,output_file])
278
+ # if music_file is not None:
279
+ # subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
280
+ # else:
281
+ # subprocess.run(['ffmpeg','-i',movie_file,output_file])
282
+ os.remove(movie_file)
283
+
284
+
285
+
286
+
287
+ def look_at(eye, center, up):
288
+ front = eye - center
289
+ front = front / np.linalg.norm(front)
290
+ right = np.cross(up, front)
291
+ right = right/ np.linalg.norm(right)
292
+ up_new = np.cross(front, right)
293
+ camera_pose = np.eye(4)
294
+ camera_pose[:3,:3] = np.stack([right, up_new, front]).transpose()
295
+ camera_pose[:3,3] = eye
296
+ return camera_pose
297
+
298
+
299
+ def motion_data_load_process(motionfile):
300
+ if motionfile.split(".")[-1] == "pkl":
301
+ pkl_data = pickle.load(open(motionfile, "rb"))
302
+ smpl_poses = pkl_data["smpl_poses"]
303
+ modata = np.concatenate((pkl_data["smpl_trans"], smpl_poses), axis=1)
304
+ if modata.shape[1] == 69:
305
+ hand_zeros = np.zeros([modata.shape[0], 90], dtype=np.float32)
306
+ modata = np.concatenate((modata, hand_zeros), axis=1)
307
+ assert modata.shape[1] == 159
308
+ modata[:, 1] = modata[:, 1] + 1.3
309
+ return modata
310
+ elif motionfile.split(".")[-1] == "npy":
311
+ modata = np.load(motionfile)
312
+ print("modata.shape", modata.shape)
313
+ if modata.shape[-1] == 315: # first 3-dim is root translation
314
+ print("modata.shape is:", modata.shape)
315
+ rot6d = torch.from_numpy(modata[:,3:])
316
+ T,C = rot6d.shape
317
+ rot6d = rot6d.reshape(-1,6)
318
+ axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
319
+ modata = np.concatenate((modata[:,:3], axis), axis=1)
320
+ print("modata.shape is:", modata.shape)
321
+ elif modata.shape[-1] == 319:
322
+ print("modata.shape is:", modata.shape)
323
+ modata = modata[:,4:]
324
+ rot6d = torch.from_numpy(modata[:,3:])
325
+ T,C = rot6d.shape
326
+ rot6d = rot6d.reshape(-1,6)
327
+ axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
328
+ modata = np.concatenate((modata[:,:3], axis), axis=1)
329
+ print("modata.shape is:", modata.shape)
330
+ elif modata.shape[-1] == 168:
331
+ modata = np.concatenate( [modata[:,:21*3+1], modata[:,25*3:]] , axis=1)
332
+ elif modata.shape[-1] == 159:
333
+ print("modata.shape is:", modata.shape)
334
+ print("modata.shape is:", modata.shape)
335
+ elif modata.shape[-1] == 135:
336
+ print("modata.shape is:", modata.shape)
337
+ if len(modata.shape) == 3 and modata.shape[0] ==1:
338
+ modata = modata.squeeze(0)
339
+ rot6d = torch.from_numpy(modata[:,3:])
340
+ T,C = rot6d.shape
341
+ rot6d = rot6d.reshape(-1,6)
342
+ axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
343
+ hand_zeros = torch.zeros([T, 90]).to(rot6d).detach().cpu().numpy()
344
+ modata = np.concatenate((modata[:,:3], axis, hand_zeros), axis=1)
345
+ print("modata.shape is:", modata.shape)
346
+ else:
347
+ raise("shape error!")
348
+
349
+ modata[:, 1] = modata[:, 1] + 1.3
350
+ return modata
351
+
352
+ if __name__ == '__main__':
353
+ parser = argparse.ArgumentParser()
354
+ parser.add_argument("--gpu", type=str, default="2")
355
+ parser.add_argument("--modir", type=str, default="")
356
+ parser.add_argument("--mode", type=str, default="smplx", choices=['smpl','smplh','smplx'])
357
+ parser.add_argument("--fps", type=int, default=30)
358
+ parser.add_argument("--save_path", type=str, default=None)
359
+ args = parser.parse_args()
360
+ print(args.gpu)
361
+
362
+
363
+ motion_dir = args.modir
364
+ if args.save_path is not None:
365
+ save_path = args.save_path
366
+ if not os.path.exists(save_path):
367
+ os.makedirs(save_path)
368
+ else:
369
+ save_path = os.path.join(motion_dir, 'video')
370
+ os.makedirs(save_path, exist_ok=True)
371
+
372
+
373
+ music_dir = "experiments/DanceDiffuse_module/debug--0517_Norm_512len_315_transloss/val1640/samples_2023-05-17-20-54-05"
374
+ for file in os.listdir(motion_dir):
375
+ if file[-3:] in ["npy", "pkl"]:
376
+
377
+ # skip files that already have a rendered video
378
+ flag = False
379
+ for exists_file in os.listdir(save_path):
380
+ if file[:-4] in exists_file:
381
+ flag = True
382
+ break
383
+ else:
384
+ flag = False
385
+ if flag:
386
+ print("exist", file)
387
+ continue
388
+
389
+ print(file)
390
+ motion_file = os.path.join(motion_dir, file)
391
+ visualizer = MovieMaker(save_path=save_path)
392
+ modata = motion_data_load_process(motion_file)
393
+ visualizer.run(modata, tab=os.path.basename(motion_file).split(".")[0], music_file=None)
394
+
395
+ print('done')
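A typical invocation, with illustrative paths; the script globs *.npy/*.pkl files from --modir and writes one mp4 per motion into --save_path (or a video/ folder next to the motions). Inputs are expected as 159-dim poses (3 translation + 156 axis-angle); 315/319/135-dim rot6d inputs are converted automatically by motion_data_load_process:

    python render.py --modir generated/motions --mode smplx --fps 30 --save_path output/videos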
smplx_neu_J_1.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa80c6e1b28e9d43470f250a2d168bacf2cb0cc5e8688ff2e48e71f4db13ba6c
3
+ size 788
teaser/teaser.png ADDED

Git LFS Details

  • SHA256: 797034e986aad1b6cd47f78b12b2406ec779315eb205f4211d07c0623f198e5f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
test.py ADDED
@@ -0,0 +1,187 @@
1
+ import glob
2
+ import os
3
+ from functools import cmp_to_key
4
+ from pathlib import Path
5
+ import sys
6
+ from tempfile import TemporaryDirectory
7
+ import random
8
+
9
+ # import jukemirlib
10
+ import numpy as np
11
+ import torch
12
+ from tqdm import tqdm
13
+ import librosa
14
+ import librosa as lr
15
+ import soundfile as sf
16
+
17
+ from args import FineDance_parse_test_opt
18
+ from train_seq import EDGE
19
+ # from data.audio_extraction.jukebox_features import extract as juke_extract
20
+
21
+ def slice_audio(audio_file, stride, length, out_dir):
22
+ # stride, length in seconds
23
+ audio, sr = lr.load(audio_file, sr=None)
24
+ file_name = os.path.splitext(os.path.basename(audio_file))[0]
25
+ start_idx = 0
26
+ idx = 0
27
+ window = int(length * sr)
28
+ stride_step = int(stride * sr)
29
+ while start_idx <= len(audio) - window:
30
+ audio_slice = audio[start_idx : start_idx + window]
31
+ sf.write(f"{out_dir}/{file_name}_slice{idx}.wav", audio_slice, sr)
32
+ start_idx += stride_step
33
+ idx += 1
34
+ return idx
35
+
36
+ def extract(fpath):
37
+ FPS = 30
38
+ HOP_LENGTH = 512
39
+ SR = FPS * HOP_LENGTH
40
+ EPS = 1e-6
41
+
42
+ data, _ = librosa.load(fpath, sr=SR)
43
+ envelope = librosa.onset.onset_strength(y=data, sr=SR) # (seq_len,)
44
+ mfcc = librosa.feature.mfcc(y=data, sr=SR, n_mfcc=20).T # (seq_len, 20)
45
+ chroma = librosa.feature.chroma_cens(
46
+ y=data, sr=SR, hop_length=HOP_LENGTH, n_chroma=12
47
+ ).T # (seq_len, 12)
48
+
49
+ peak_idxs = librosa.onset.onset_detect(
50
+ onset_envelope=envelope.flatten(), sr=SR, hop_length=HOP_LENGTH
51
+ )
52
+ peak_onehot = np.zeros_like(envelope, dtype=np.float32)
53
+ peak_onehot[peak_idxs] = 1.0 # (seq_len,)
54
+
55
+ start_bpm = lr.beat.tempo(y=lr.load(fpath)[0])[0]
56
+
57
+ tempo, beat_idxs = librosa.beat.beat_track(
58
+ onset_envelope=envelope,
59
+ sr=SR,
60
+ hop_length=HOP_LENGTH,
61
+ start_bpm=start_bpm,
62
+ tightness=100,
63
+ )
64
+ beat_onehot = np.zeros_like(envelope, dtype=np.float32)
65
+ beat_onehot[beat_idxs] = 1.0 # (seq_len,)
66
+
67
+ audio_feature = np.concatenate(
68
+ [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
69
+ axis=-1,
70
+ )
71
+
72
+ # chop to ensure exact shape
73
+ audio_feature = audio_feature[:4 * FPS]
74
+ return audio_feature
75
+
76
+ # sort filenames that look like songname_slice{number}.ext
77
+ key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])
78
+ # test_list = ["063", "132", "143", "036", "098", "198", "130", "012", "211", "193", "179", "065", "137", "161", "092", "120", "037", "109", "204", "144"]
79
+ test_list = ["063", "144"]
80
+
81
+ def stringintcmp_(a, b):
82
+ aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
83
+ ka, kb = key_func(a), key_func(b)
84
+ if aa < bb:
85
+ return -1
86
+ if aa > bb:
87
+ return 1
88
+ if ka < kb:
89
+ return -1
90
+ if ka > kb:
91
+ return 1
92
+ return 0
93
+
94
+ stringintkey = cmp_to_key(stringintcmp_)
95
+ stride_ = 60/30
96
+
97
+ def test(opt):
98
+ feature_func = extract
99
+ sample_length = opt.out_length
100
+ sample_size = int(sample_length / stride_) - 1
101
+ temp_dir_list = []
102
+ all_cond = []
103
+ all_filenames = []
104
+ if opt.use_cached_features: # default is false
105
+ print("Using precomputed features")
106
+ # all subdirectories
107
+ dir_list = glob.glob(os.path.join(opt.feature_cache_dir, "*/"))
108
+ for dir in dir_list:
109
+ file_list = sorted(glob.glob(f"{dir}/*.wav"), key=stringintkey)
110
+ juke_file_list = sorted(glob.glob(f"{dir}/*.npy"), key=stringintkey)
111
+ assert len(file_list) == len(juke_file_list)
112
+
113
+ # random chunk after sanity check
114
+ rand_idx = random.randint(0, len(file_list) - sample_size)
115
+ file_list = file_list[rand_idx : rand_idx + sample_size]
116
+ juke_file_list = juke_file_list[rand_idx : rand_idx + sample_size]
117
+ cond_list = [np.load(x) for x in juke_file_list]
118
+ all_filenames.append(file_list)
119
+ all_cond.append(torch.from_numpy(np.array(cond_list)))
120
+ else:
121
+ print("Computing features for input music")
122
+ for wav_file in glob.glob(os.path.join(opt.music_dir, "*.wav")):
123
+ songname = os.path.splitext(os.path.basename(wav_file))[0]
124
+ # create temp folder (or use the cache folder if specified)
125
+ if True: # songname in test_list:
126
+ if opt.cache_features:
127
+ save_dir = os.path.join(opt.feature_cache_dir, songname)
128
+ Path(save_dir).mkdir(parents=True, exist_ok=True)
129
+ dirname = save_dir
130
+ else:
131
+ temp_dir = TemporaryDirectory()
132
+ print("temp_dir is", temp_dir)
133
+ temp_dir_list.append(temp_dir)
134
+ dirname = temp_dir.name
135
+ # slice the audio file
136
+ print(f"Slicing {wav_file}")
137
+ slice_audio(wav_file, 60/30, 120/30, dirname)
138
+ file_list = sorted(glob.glob(f"{dirname}/*.wav"), key=stringintkey)
139
+ # randomly sample a chunk of length at most sample_size
140
+ rand_idx = random.randint(0, len(file_list) - sample_size)
141
+ cond_list = []
142
+ # generate juke representations
143
+ print(f"Computing features for {wav_file}")
144
+ for idx, file in enumerate(tqdm(file_list)):
145
+ # if not caching then only calculate for the interested range
146
+ if (not opt.cache_features) and (not (rand_idx <= idx < rand_idx + sample_size)):
147
+ continue
148
+ # audio = jukemirlib.load_audio(file)
149
+ # reps = jukemirlib.extract(
150
+ # audio, layers=[66], downsample_target_rate=30
151
+ # )[66]
152
+ reps = feature_func(file)[:opt.full_seq_len]
153
+ # save reps
154
+ if opt.cache_features:
155
+ featurename = os.path.splitext(file)[0] + ".npy"
156
+ np.save(featurename, reps)
157
+ # if in the random range, put it into the list of reps we want
158
+ # to actually use for generation
159
+ if rand_idx <= idx < rand_idx + sample_size:
160
+ cond_list.append(reps)
161
+ cond_list = torch.from_numpy(np.array(cond_list))
162
+ all_cond.append(cond_list)
163
+ all_filenames.append(file_list[rand_idx : rand_idx + sample_size])
164
+
165
+ model = EDGE(opt, opt.feature_type, opt.checkpoint)
166
+ model.eval()
167
+
168
+ # directory for optionally saving the dances for eval
169
+ fk_out = None
170
+ if opt.save_motions:
171
+ fk_out = opt.motion_save_dir
172
+
173
+ print("Generating dances")
174
+ for i in range(len(all_cond)):
175
+ data_tuple = None, all_cond[i], all_filenames[i]
176
+ model.render_sample(
177
+ data_tuple, "test", opt.render_dir, render_count=-1, fk_out=fk_out, mode="long", render=not opt.no_render
178
+ )
179
+ print("Done")
180
+ if torch.cuda.is_available():
181
+ torch.cuda.empty_cache()
182
+ for temp_dir in temp_dir_list:
183
+ temp_dir.cleanup()
184
+
185
+ if __name__ == "__main__":
186
+ opt = FineDance_parse_test_opt()
187
+ test(opt)
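The extract function above produces the 35-dimensional baseline music feature used across the project (feature_dim = 35 in train_seq.py). A sketch of the per-frame layout, with a zero array standing in for a real feature matrix:

    import numpy as np

    # columns produced by extract(), in concatenation order:
    #   1  onset-strength envelope
    #   20 MFCCs
    #   12 CENS chroma bins
    #   1  onset-peak one-hot
    #   1  beat one-hot          -> 35 dims per 30 fps frame
    feature = np.zeros((120, 35), dtype=np.float32)  # hypothetical 4 s slice at 30 fps
    envelope = feature[:, 0]
    mfcc = feature[:, 1:21]
    chroma = feature[:, 21:33]
    peak_onehot, beat_onehot = feature[:, 33], feature[:, 34]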
train_seq.py ADDED
@@ -0,0 +1,318 @@
1
+ import multiprocessing
2
+ import os
3
+ from zlib import Z_FULL_FLUSH
4
+ # os.environ["WANDB_API_KEY"] = "your WANDB_API_KEY" #
5
+ # os.environ["WANDB_MODE"] = "online"
6
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
7
+ import pickle
8
+ from functools import partial
9
+ from pathlib import Path
10
+ from args import FineDance_parse_train_opt, save_arguments_to_yaml
11
+ import sys
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import wandb
16
+ from accelerate import Accelerator, DistributedDataParallelKwargs
17
+ from accelerate.state import AcceleratorState
18
+ from torch.utils.data import DataLoader
19
+ from tqdm import tqdm
20
+
21
+ from dataset.FineDance_dataset import FineDance_Smpl
22
+ from dataset.preprocess import increment_path
23
+ from dataset.preprocess import My_Normalizer as Normalizer # do not use Normalizer
24
+ from model.adan import Adan
25
+ from model.diffusion import GaussianDiffusion
26
+ from model.model import DanceDecoder, SeqModel
27
+ from vis import SMPLX_Skeleton, SMPLSkeleton
28
+
29
+
30
+ def wrap(x):
31
+ return {f"module.{key}": value for key, value in x.items()}
32
+
33
+
34
+ def maybe_wrap(x, num):
35
+ return x if num == 1 else wrap(x)
36
+
37
+
38
+ class EDGE:
39
+ def __init__(
40
+ self,
41
+ opt,
42
+ feature_type,
43
+ checkpoint_path="",
44
+ normalizer=None,
45
+ EMA=True,
46
+ learning_rate=4e-4,
47
+ weight_decay=0.02,
48
+ ):
49
+ self.opt = opt
50
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
51
+ self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
52
+ state = AcceleratorState()
53
+ num_processes = state.num_processes
54
+
55
+ self.repr_dim = repr_dim = opt.nfeats
56
+ feature_dim = 35
57
+
58
+ self.horizon = horizon = opt.full_seq_len
59
+
60
+ self.accelerator.wait_for_everyone()
61
+
62
+ self.resume_num = 0
63
+ checkpoint = None
64
+ self.normalizer = None
65
+ if checkpoint_path != "":
66
+ checkpoint = torch.load(
67
+ checkpoint_path, map_location=self.accelerator.device
68
+ )
69
+ self.resume_num = int(os.path.basename(checkpoint_path).split("-")[1].split(".")[0]) # parse the epoch out of "train-{epoch}.pt"
70
+
71
+ model = SeqModel(
72
+ nfeats=repr_dim,
73
+ seq_len=horizon,
74
+ latent_dim=512,
75
+ ff_size=1024,
76
+ num_layers=8,
77
+ num_heads=8,
78
+ dropout=0.1,
79
+ cond_feature_dim=feature_dim,
80
+ activation=F.gelu,
81
+ )
82
+ if opt.nfeats == 139 or opt.nfeats == 135:
83
+ smplx_fk = SMPLSkeleton(device=self.accelerator.device)
84
+ else:
85
+ smplx_fk = SMPLX_Skeleton(device=self.accelerator.device, batch=512000)
86
+ diffusion = GaussianDiffusion(
87
+ model,
88
+ opt,
89
+ horizon,
90
+ repr_dim,
91
+ smplx_model = smplx_fk,
92
+ schedule="cosine",
93
+ n_timestep=1000,
94
+ predict_epsilon=False,
95
+ loss_type="l2",
96
+ use_p2=False,
97
+ cond_drop_prob=0.25,
98
+ guidance_weight=2,
99
+ do_normalize = opt.do_normalize
100
+ )
101
+
102
+ print(
103
+ "Model has {} parameters".format(sum(y.numel() for y in model.parameters()))
104
+ )
105
+
106
+ self.model = self.accelerator.prepare(model)
107
+ self.diffusion = diffusion.to(self.accelerator.device) # why is accelerator.prepare not needed here?
108
+ self.smplx_fk = smplx_fk # to(self.accelerator.device)
109
+ optim = Adan(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
110
+ self.optim = self.accelerator.prepare(optim)
111
+
112
+ if checkpoint_path != "":
113
+ self.model.load_state_dict(
114
+ maybe_wrap(
115
+ checkpoint["ema_state_dict" if EMA else "model_state_dict"],
116
+ num_processes,
117
+ )
118
+ )
119
+
120
+ def eval(self):
121
+ self.diffusion.eval()
122
+
123
+ def train(self):
124
+ self.diffusion.train()
125
+
126
+ def prepare(self, objects):
127
+ return self.accelerator.prepare(*objects)
128
+
129
+ def train_loop(self, opt):
130
+ print("train_dataset = FineDance_Dataset ")
131
+ train_dataset = FineDance_Smpl(
132
+ args=opt, # data/
133
+ istrain=True,
134
+ )
135
+ test_dataset = FineDance_Smpl(
136
+ args=opt,
137
+ istrain=False,
138
+ )
139
+
140
+ num_cpus = multiprocessing.cpu_count()
141
+ print("batchsize=:", opt.batch_size)
142
+ train_data_loader = DataLoader(
143
+ train_dataset,
144
+ batch_size=opt.batch_size,
145
+ shuffle=True,
146
+ num_workers=min(int(num_cpus * 0.5), 40), # num_workers=min(int(num_cpus * 0.75), 32),
147
+ pin_memory=True,
148
+ drop_last=True,
149
+ )
150
+ test_data_loader = DataLoader(
151
+ test_dataset,
152
+ batch_size=opt.batch_size,
153
+ shuffle=True,
154
+ num_workers=2,
155
+ pin_memory=True,
156
+ drop_last=True,
157
+ )
158
+
159
+ train_data_loader = self.accelerator.prepare(train_data_loader)
160
+ # boot up multi-gpu training. test dataloader is only on main process
161
+ load_loop = (
162
+ partial(tqdm, position=1, desc="Batch")
163
+ if self.accelerator.is_main_process
164
+ else lambda x: x
165
+ )
166
+ if self.accelerator.is_main_process:
167
+ save_dir = str(increment_path(Path(opt.project) / opt.exp_name))
168
+ opt.exp_name = save_dir.split("/")[-1]
169
+ wandb.init(project=opt.wandb_pj_name, name=opt.exp_name)
170
+ save_dir = Path(save_dir)
171
+ wdir = save_dir / "weights"
172
+ wdir.mkdir(parents=True, exist_ok=True)
173
+ wandb.save("params.yaml") # 保存wandb配置到文件
174
+ yaml_path = os.path.join(wdir, 'parameters.yaml')
175
+ save_arguments_to_yaml(opt, yaml_path)
176
+
177
+
178
+ self.accelerator.wait_for_everyone()
179
+ for epoch in range(1, opt.epochs + 1):
180
+ print("epoch:", epoch+self.resume_num)
181
+ avg_loss = 0
182
+ avg_vloss = 0
183
+ avg_fkloss = 0
184
+ avg_footloss = 0
185
+
186
+ # train
187
+ self.train()
188
+ for step, (x, cond, filename) in enumerate(
189
+ load_loop(train_data_loader)
190
+ ):
191
+ if opt.nfeats == 139 or opt.nfeats==135:
192
+ x = x[:, :, :139]
193
+
194
+ total_loss, (loss, v_loss, fk_loss, foot_loss) = self.diffusion(
195
+ x, cond, t_override=None
196
+ )
197
+ # print("3")
198
+ self.optim.zero_grad()
199
+ self.accelerator.backward(total_loss)
200
+ self.optim.step()
201
+
202
+ # ema update and train loss update only on main
203
+ if self.accelerator.is_main_process:
204
+ avg_loss += loss.detach().cpu().numpy()
205
+ avg_vloss += v_loss.detach().cpu().numpy()
206
+ avg_fkloss += fk_loss.detach().cpu().numpy()
207
+ avg_footloss += foot_loss.detach().cpu().numpy()
208
+ if step % opt.ema_interval == 0:
209
+ self.diffusion.ema.update_model_average(
210
+ self.diffusion.master_model, self.diffusion.model
211
+ )
212
+
213
+ #-----------------------------------------------------------------------------------------------------------
214
+ # test
215
+ # Save model
216
+
217
+ if ((epoch+self.resume_num) % opt.save_interval) == 0 or epoch<=1:
218
+ # everyone waits here for the val loop to finish ( don't start next train epoch early)
219
+ self.accelerator.wait_for_everyone()
220
+ self.eval() # switch to eval mode before validation sampling
221
+ # save only if on main thread
222
+ if self.accelerator.is_main_process:
223
+ # self.eval()
224
+ # log
225
+ avg_loss /= len(train_data_loader)
226
+ avg_vloss /= len(train_data_loader)
227
+ avg_fkloss /= len(train_data_loader)
228
+ avg_footloss /= len(train_data_loader)
229
+ log_dict = {
230
+ "Train Loss": avg_loss,
231
+ "V Loss": avg_vloss,
232
+ "FK Loss": avg_fkloss,
233
+ "Foot Loss": avg_footloss,
234
+ }
235
+
236
+ wandb.log(log_dict)
237
+
238
+ ckpt = {
239
+ "ema_state_dict": self.diffusion.master_model.state_dict(), # 经过accelerate prepare的模型,在保存时需要unwrap,反之不需要
240
+ "model_state_dict": self.accelerator.unwrap_model(
241
+ self.model
242
+ ).state_dict(),
243
+ "optimizer_state_dict": self.optim.state_dict(),
244
+ "normalizer": self.normalizer,
245
+ }
246
+
247
+ torch.save(ckpt, os.path.join(wdir, f"train-{epoch+self.resume_num}.pt"))
248
+ print(f"[MODEL SAVED at Epoch {epoch+self.resume_num}]")
249
+
250
+ # generate a sample
251
+ render_count = 2
252
+ shape = (render_count, self.horizon, self.opt.nfeats)
253
+ print("Generating Sample")
254
+ # draw a music from the test dataset
255
+ (x, cond, filename) = next(iter(test_data_loader))
256
+ # if opt.do_normalize:
257
+ # x = self.normalizer.normalize(x)
258
+
259
+ if opt.nfeats == 139 or opt.nfeats==135:
260
+ x = x[:, :, :139]
261
+
262
+ cond = cond.to(self.accelerator.device)
263
+ # name_iter = name_iter+1
264
+ self.diffusion.render_sample(
265
+ shape,
266
+ cond[:render_count],
267
+ self.normalizer,
268
+ epoch+self.resume_num,
269
+ render_out = os.path.join(opt.render_dir, "train_" + opt.exp_name), # render out
270
+ fk_out = os.path.join(opt.render_dir, "train_" + opt.exp_name),
271
+ name=filename[:render_count],
272
+ # name = str(epoch) + str(name_iter).zfill(3)
273
+ sound=True,
274
+ )
275
+ #-----------------------------------------------------------------------------------------------------------
276
+
277
+
278
+ if self.accelerator.is_main_process:
279
+ wandb.run.finish()
280
+
281
+ def render_sample(
282
+ self, data_tuple, label, render_dir, render_count=-1, mode='normal', fk_out=None, render=True,
283
+ ):
284
+ _, cond, wavname = data_tuple
285
+ assert len(cond.shape) == 3
286
+ if render_count < 0:
287
+ render_count = len(cond)
288
+ shape = (render_count, self.horizon, self.repr_dim)
289
+ cond = cond.to(self.accelerator.device).float()
290
+ self.diffusion.render_sample(
291
+ shape,
292
+ cond[:render_count],
293
+ self.normalizer,
294
+ label,
295
+ render_dir,
296
+ name=wavname[:render_count],
297
+ sound=True,
298
+ mode=mode,
299
+ fk_out=fk_out,
300
+ render=render
301
+ )
302
+
303
+ def train(opt):
304
+ model = EDGE(opt, opt.feature_type)
305
+ model.train_loop(opt)
306
+
307
+ if __name__ == "__main__":
308
+ opt = FineDance_parse_train_opt()
309
+ command = ' '.join(sys.argv)
310
+ if not os.path.exists(os.path.join(opt.project, opt.exp_name)):
311
+ os.makedirs(os.path.join(opt.project, opt.exp_name), exist_ok=False)
312
+ with open(os.path.join(opt.project, opt.exp_name, 'command.txt'), 'w') as f:
313
+ f.write(command)
314
+
315
+ yaml_path = os.path.join(opt.project, opt.exp_name, 'parameters.yaml')
316
+ save_arguments_to_yaml(opt, yaml_path)
317
+
318
+ train(opt)
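Checkpoints are written as train-{epoch}.pt and the epoch is recovered from that filename when resuming (the resume_num parsing in EDGE.__init__). A minimal round-trip sketch with a hypothetical epoch number:

    import os

    epoch = 1640
    ckpt_name = f"train-{epoch}.pt"  # naming used in train_loop
    resumed = int(os.path.basename(ckpt_name).split("-")[1].split(".")[0])
    assert resumed == epoch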
vis.py ADDED
@@ -0,0 +1,687 @@
1
+ import os
2
+ from pathlib import Path
3
+ import sys
4
+ from tempfile import TemporaryDirectory
5
+
6
+ import librosa as lr
7
+ import matplotlib.animation as animation
8
+ import matplotlib.pyplot as plt
9
+ from mpl_toolkits.mplot3d import axes3d
10
+
11
+ import numpy as np
12
+ import soundfile as sf
13
+ import torch
14
+ from matplotlib import cm
15
+ from matplotlib.colors import ListedColormap
16
+ from pytorch3d.transforms import (axis_angle_to_quaternion, quaternion_apply,
17
+ quaternion_multiply)
18
+ from tqdm import tqdm
19
+ from typing import NewType
20
+ Tensor = NewType('Tensor', torch.Tensor)
21
+ import torch.nn.functional as F
22
+ try:
23
+ import pickle5 as pickle
24
+ except ImportError:
25
+ import pickle
26
+
27
+
28
+ smpl_joints = [
29
+ "root", # 0
30
+ "lhip", # 1
31
+ "rhip", # 2
32
+ "belly", # 3
33
+ "lknee", # 4
34
+ "rknee", # 5
35
+ "spine", # 6
36
+ "lankle",# 7
37
+ "rankle",# 8
38
+ "chest", # 9
39
+ "ltoes", # 10
40
+ "rtoes", # 11
41
+ "neck", # 12
42
+ "linshoulder", # 13
43
+ "rinshoulder", # 14
44
+ "head", # 15
45
+ "lshoulder", # 16
46
+ "rshoulder", # 17
47
+ "lelbow", # 18
48
+ "relbow", # 19
49
+ "lwrist", # 20
50
+ "rwrist", # 21
51
+ "lhand", # 22
52
+ "rhand", # 23
53
+ ]
54
+
55
+ smplh_joints = [
56
+ 'pelvis',
57
+ 'left_hip',
58
+ 'right_hip',
59
+ 'spine1',
60
+ 'left_knee',
61
+ 'right_knee',
62
+ 'spine2',
63
+ 'left_ankle',
64
+ 'right_ankle',
65
+ 'spine3',
66
+ 'left_foot',
67
+ 'right_foot',
68
+ 'neck',
69
+ 'left_collar',
70
+ 'right_collar',
71
+ 'head',
72
+ 'left_shoulder',
73
+ 'right_shoulder',
74
+ 'left_elbow',
75
+ 'right_elbow',
76
+ 'left_wrist',
77
+ 'right_wrist',
78
+ 'left_index1',
79
+ 'left_index2',
80
+ 'left_index3',
81
+ 'left_middle1',
82
+ 'left_middle2',
83
+ 'left_middle3',
84
+ 'left_pinky1',
85
+ 'left_pinky2',
86
+ 'left_pinky3',
87
+ 'left_ring1',
88
+ 'left_ring2',
89
+ 'left_ring3',
90
+ 'left_thumb1',
91
+ 'left_thumb2',
92
+ 'left_thumb3',
93
+ 'right_index1',
94
+ 'right_index2',
95
+ 'right_index3',
96
+ 'right_middle1',
97
+ 'right_middle2',
98
+ 'right_middle3',
99
+ 'right_pinky1',
100
+ 'right_pinky2',
101
+ 'right_pinky3',
102
+ 'right_ring1',
103
+ 'right_ring2',
104
+ 'right_ring3',
105
+ 'right_thumb1',
106
+ 'right_thumb2',
107
+ 'right_thumb3'
108
+ ]
109
+
110
+
111
+ smplx_joints = [
112
+ 'pelvis',
113
+ 'left_hip',
114
+ 'right_hip',
115
+ 'spine1',
116
+ 'left_knee',
117
+ 'right_knee',
118
+ 'spine2',
119
+ 'left_ankle',
120
+ 'right_ankle',
121
+ 'spine3',
122
+ 'left_foot',
123
+ 'right_foot',
124
+ 'neck',
125
+ 'left_collar',
126
+ 'right_collar',
127
+ 'head',
128
+ 'left_shoulder',
129
+ 'right_shoulder',
130
+ 'left_elbow',
131
+ 'right_elbow',
132
+ 'left_wrist',
133
+ 'right_wrist',
134
+ 'jaw',
135
+ 'left_eye_smplhf',
136
+ 'right_eye_smplhf',
137
+ 'left_index1',
138
+ 'left_index2',
139
+ 'left_index3',
140
+ 'left_middle1',
141
+ 'left_middle2',
142
+ 'left_middle3',
143
+ 'left_pinky1',
144
+ 'left_pinky2',
145
+ 'left_pinky3',
146
+ 'left_ring1',
147
+ 'left_ring2',
148
+ 'left_ring3',
149
+ 'left_thumb1',
150
+ 'left_thumb2',
151
+ 'left_thumb3',
152
+ 'right_index1',
153
+ 'right_index2',
154
+ 'right_index3',
155
+ 'right_middle1',
156
+ 'right_middle2',
157
+ 'right_middle3',
158
+ 'right_pinky1',
159
+ 'right_pinky2',
160
+ 'right_pinky3',
161
+ 'right_ring1',
162
+ 'right_ring2',
163
+ 'right_ring3',
164
+ 'right_thumb1',
165
+ 'right_thumb2',
166
+ 'right_thumb3'
167
+ ]
168
+
169
+
170
+ smpl_parents = [
171
+ -1,
172
+ 0,
173
+ 0,
174
+ 0,
175
+ 1,
176
+ 2,
177
+ 3,
178
+ 4,
179
+ 5,
180
+ 6,
181
+ 7,
182
+ 8,
183
+ 9,
184
+ 9,
185
+ 9,
186
+ 12,
187
+ 13,
188
+ 14,
189
+ 16,
190
+ 17,
191
+ 18,
192
+ 19,
193
+ 20,
194
+ 21,
195
+ ]
196
+
197
+ smplh_parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14,
198
+ 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34,
199
+ 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50]
200
+
201
+ smplx_parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 15, 15, 15, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 20, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50, 21, 52, 53]
202
+
203
+
204
+ smpl_offsets = [
205
+ [0.0, 0.0, 0.0],
206
+ [0.05858135, -0.08228004, -0.01766408],
207
+ [-0.06030973, -0.09051332, -0.01354254],
208
+ [0.00443945, 0.12440352, -0.03838522],
209
+ [0.04345142, -0.38646945, 0.008037],
210
+ [-0.04325663, -0.38368791, -0.00484304],
211
+ [0.00448844, 0.1379564, 0.02682033],
212
+ [-0.01479032, -0.42687458, -0.037428],
213
+ [0.01905555, -0.4200455, -0.03456167],
214
+ [-0.00226458, 0.05603239, 0.00285505],
215
+ [0.04105436, -0.06028581, 0.12204243],
216
+ [-0.03483987, -0.06210566, 0.13032329],
217
+ [-0.0133902, 0.21163553, -0.03346758],
218
+ [0.07170245, 0.11399969, -0.01889817],
219
+ [-0.08295366, 0.11247234, -0.02370739],
220
+ [0.01011321, 0.08893734, 0.05040987],
221
+ [0.12292141, 0.04520509, -0.019046],
222
+ [-0.11322832, 0.04685326, -0.00847207],
223
+ [0.2553319, -0.01564902, -0.02294649],
224
+ [-0.26012748, -0.01436928, -0.03126873],
225
+ [0.26570925, 0.01269811, -0.00737473],
226
+ [-0.26910836, 0.00679372, -0.00602676],
227
+ [0.08669055, -0.01063603, -0.01559429],
228
+ [-0.0887537, -0.00865157, -0.01010708],
229
+ ]
230
+
231
+
232
+ def set_line_data_3d(line, x):
233
+ line.set_data(x[:, :2].T)
234
+ line.set_3d_properties(x[:, 2])
235
+
236
+
237
+ def set_scatter_data_3d(scat, x, c):
238
+ scat.set_offsets(x[:, :2])
239
+ scat.set_3d_properties(x[:, 2], "z")
240
+ scat.set_facecolors([c])
241
+
242
+
243
+ def get_axrange(poses):
244
+ pose = poses[0]
245
+ x_min = pose[:, 0].min()
246
+ x_max = pose[:, 0].max()
247
+
248
+ y_min = pose[:, 1].min()
249
+ y_max = pose[:, 1].max()
250
+
251
+ z_min = pose[:, 2].min()
252
+ z_max = pose[:, 2].max()
253
+
254
+ xdiff = x_max - x_min
255
+ ydiff = y_max - y_min
256
+ zdiff = z_max - z_min
257
+
258
+ biggestdiff = max([xdiff, ydiff, zdiff])
259
+ return biggestdiff
260
+
261
+
262
+ def plot_single_pose(num, poses, lines, ax, axrange, scat, contact, ske_parents):
263
+ pose = poses[num]
264
+ static = contact[num]
265
+ indices = [7, 8, 10, 11]
266
+
267
+ for i, (point, idx) in enumerate(zip(scat, indices)):
268
+ position = pose[idx : idx + 1]
269
+ color = "r" if static[i] else "g"
270
+ set_scatter_data_3d(point, position, color)
271
+
272
+ for i, (p, line) in enumerate(zip(ske_parents, lines)):
273
+ # don't plot root
274
+ if i == 0:
275
+ continue
276
+ # stack to create a line
277
+ data = np.stack((pose[i], pose[p]), axis=0)
278
+ set_line_data_3d(line, data)
279
+
280
+ if num == 0:
281
+ if isinstance(axrange, int):
282
+ axrange = (axrange, axrange, axrange)
283
+ xcenter, ycenter, zcenter = 0, 0, 2.5
284
+ stepx, stepy, stepz = axrange[0] / 2, axrange[1] / 2, axrange[2] / 2
285
+
286
+ x_min, x_max = xcenter - stepx, xcenter + stepx
287
+ y_min, y_max = ycenter - stepy, ycenter + stepy
288
+ z_min, z_max = zcenter - stepz, zcenter + stepz
289
+
290
+ ax.set_xlim(x_min, x_max)
291
+ ax.set_ylim(y_min, y_max)
292
+ ax.set_zlim(z_min, z_max)
293
+
294
+
295
+ def skeleton_render(
296
+ poses,
297
+ epoch=0,
298
+ out="renders",
299
+ name="",
300
+ sound=True,
301
+ stitch=False,
302
+ sound_folder="ood_sliced",
303
+ contact=None,
304
+ render=True,
305
+ smpl_mode="smpl", # 是否渲染双手
306
+ ):
307
+ if render:
308
+ if smpl_mode=="smpl":
309
+ poses = np.concatenate((poses[:, :23, :], np.expand_dims(poses[:, 37, :], axis=1)), axis=1)
310
+ ske_parents = smpl_parents
311
+ elif smpl_mode == "smplx":
312
+ ske_parents = smplx_parents
313
+
314
+ # generate the pose with FK
315
+ Path(out).mkdir(parents=True, exist_ok=True)
316
+ num_steps = poses.shape[0] #
317
+
318
+ fig = plt.figure()
319
+ ax = fig.add_subplot(projection="3d")
320
+
321
+ point = np.array([0, 0, 1])
322
+ normal = np.array([0, 0, 1])
323
+ d = -point.dot(normal)
324
+ xx, yy = np.meshgrid(np.linspace(-1.5, 1.5, 2), np.linspace(-1.5, 1.5, 2))
325
+ z = (-normal[0] * xx - normal[1] * yy - d) * 1.0 / normal[2]
326
+ # plot the plane
327
+ ax.plot_surface(xx, yy, z, zorder=-11, cmap=cm.twilight)
328
+ # Create lines initially without data
329
+ lines = [
330
+ ax.plot([], [], [], zorder=10, linewidth=1.5)[0]
331
+ for _ in ske_parents
332
+ ]
333
+ scat = [
334
+ ax.scatter([], [], [], zorder=10, s=0, cmap=ListedColormap(["r", "g", "b"]))
335
+ for _ in range(4)
336
+ ]
337
+ axrange = 3
338
+
339
+ # create contact labels
340
+ feet = poses[:, (7, 8, 10, 11)]
341
+ feetv = np.zeros(feet.shape[:2])
342
+ feetv[:-1] = np.linalg.norm(feet[1:] - feet[:-1], axis=-1)
343
+ if contact is None:
344
+ contact = feetv < 0.01
345
+ else:
346
+ contact = contact > 0.95
347
+
348
+ # Creating the Animation object
349
+ anim = animation.FuncAnimation(
350
+ fig,
351
+ plot_single_pose,
352
+ num_steps,
353
+ fargs=(poses, lines, ax, axrange, scat, contact, ske_parents),
354
+ interval=1000 // 30,
355
+ )
356
+ if sound:
357
+ # make a temporary directory to save the intermediate gif in
358
+ if render:
359
+ temp_dir = TemporaryDirectory()
360
+ gifname = os.path.join(temp_dir.name, f"{epoch}.gif")
361
+ anim.save(gifname)
362
+
363
+ # stitch wavs
364
+ if stitch:
365
+ assert type(name) == list # must be a list of names to do stitching
366
+ name_ = [os.path.splitext(x)[0] + ".wav" for x in name]
367
+ audio, sr = lr.load(name_[0], sr=None)
368
+ ll, half = len(audio), len(audio) // 2
369
+ total_wav = np.zeros(ll + half * (len(name_) - 1))
370
+ total_wav[:ll] = audio
371
+ idx = ll
372
+ for n_ in name_[1:]:
373
+ audio, sr = lr.load(n_, sr=None)
374
+ total_wav[idx : idx + half] = audio[half:]
375
+ idx += half
376
+ # save a dummy spliced audio
377
+ audioname = f"{temp_dir.name}/tempsound.wav" if render else os.path.join(out, f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.wav')
378
+ sf.write(audioname, total_wav, sr)
379
+ outname = os.path.join(
380
+ out,
381
+ f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.mp4',
382
+ )
383
+ else:
384
+ assert type(name) == str
385
+ assert name != "", "Must provide an audio filename"
386
+ audioname = name
387
+ outname = os.path.join(
388
+ out, f"{epoch}_{os.path.splitext(os.path.basename(name))[0]}.mp4"
389
+ )
390
+ if render:
391
+ print(f"ffmpeg -loglevel error -stream_loop 0 -y -i {gifname} -i {audioname} -shortest -c:v libx264 -crf 26 -c:a aac -q:a 4 {outname}")
392
+ out = os.system(
393
+ f"/home/lrh/Documents/ffmpeg-6.0-amd64-static/ffmpeg -loglevel error -stream_loop 0 -y -i {gifname} -i {audioname} -shortest -c:v libx264 -crf 26 -c:a aac -q:a 4 {outname}"
394
+ )
395
+ else:
396
+ if render:
397
+ # actually save the gif
398
+ path = os.path.normpath(name)
399
+ pathparts = path.split(os.sep)
400
+ gifname = os.path.join(out, f"{pathparts[-1][:-4]}.gif")
401
+ anim.save(gifname, savefig_kwargs={"transparent": True, "facecolor": "none"},)
402
+ plt.close()
403
+
404
+
405
+ class SMPLSkeleton:
406
+ def __init__(
407
+ self, device=None,
408
+ ):
409
+ offsets = smpl_offsets
410
+ parents = smpl_parents
411
+ assert len(offsets) == len(parents)
412
+
413
+ self._offsets = torch.Tensor(offsets) #.to(device)
414
+ self._parents = np.array(parents)
415
+ self._compute_metadata()
416
+
417
+ def _compute_metadata(self):
418
+ self._has_children = np.zeros(len(self._parents)).astype(bool)
419
+ for i, parent in enumerate(self._parents):
420
+ if parent != -1:
421
+ self._has_children[parent] = True
422
+
423
+ self._children = []
424
+ for i, parent in enumerate(self._parents):
425
+ self._children.append([])
426
+ for i, parent in enumerate(self._parents):
427
+ if parent != -1:
428
+ self._children[parent].append(i)
429
+
430
+ def forward(self, rotations, root_positions):
431
+ """
432
+ Perform forward kinematics using the given trajectory and local rotations.
433
+ Arguments (where N = batch size, L = sequence length, J = number of joints):
434
+ -- rotations: (N, L, J, 3) tensor of axis-angle rotations describing the local rotations of each joint.
435
+ -- root_positions: (N, L, 3) tensor describing the root joint positions.
436
+ """
437
+ assert len(rotations.shape) == 4
438
+ assert len(root_positions.shape) == 3
439
+ # transform from axis angle to quaternion
440
+ fk_device = rotations.device
441
+ self._offsets = self._offsets.to(fk_device) # .to() is not in-place, so reassign
442
+ rotations = axis_angle_to_quaternion(rotations)
443
+
444
+ positions_world = []
445
+ rotations_world = []
446
+
447
+ expanded_offsets = self._offsets.expand(
448
+ rotations.shape[0],
449
+ rotations.shape[1],
450
+ self._offsets.shape[0],
451
+ self._offsets.shape[1],
452
+ ).to(fk_device)
453
+
454
+ # Parallelize along the batch and time dimensions
455
+ for i in range(self._offsets.shape[0]):
456
+ if self._parents[i] == -1:
457
+ positions_world.append(root_positions)
458
+ rotations_world.append(rotations[:, :, 0])
459
+ else:
460
+ positions_world.append(
461
+ quaternion_apply(
462
+ rotations_world[self._parents[i]], expanded_offsets[:, :, i]
463
+ )
464
+ + positions_world[self._parents[i]]
465
+ )
466
+ if self._has_children[i]:
467
+ rotations_world.append(
468
+ quaternion_multiply(
469
+ rotations_world[self._parents[i]], rotations[:, :, i]
470
+ )
471
+ )
472
+ else:
473
+ # This joint is a terminal node -> it would be useless to compute the transformation
474
+ rotations_world.append(None)
475
+
476
+ return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)
477
+ 
+ 
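+ # Usage sketch (illustrative, not from the original file): forward kinematics
+ # for the 24-joint SMPL skeleton, with N clips of L frames each:
+ #   sk = SMPLSkeleton()
+ #   rots = torch.zeros(2, 60, 24, 3)  # per-joint axis-angle rotations (N, L, J, 3)
+ #   root = torch.zeros(2, 60, 3)      # root trajectory (N, L, 3)
+ #   pos = sk.forward(rots, root)      # -> (2, 60, 24, 3) world joint positions
+ 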
+ @torch.no_grad()  # NOTE: on a class this only disables grad during construction, not in later method calls
+ class SMPLX_Skeleton:
+     def __init__(
+         self, device=None, batch=64,
+     ):
+         self.device = device
+         self.parents = smplx_parents
+         # rest-pose joint locations, pre-broadcast to the expected batch size
+         self.J = np.load(os.path.join(os.path.dirname(__file__), "smplx_neu_J_1.npy"))
+         self.J = torch.from_numpy(self.J).to(device).unsqueeze(dim=0).repeat(batch, 1, 1)
+ 
+     def batch_rodrigues(self, rot_vecs: Tensor, epsilon: float = 1e-8) -> Tensor:
+         ''' Calculates the rotation matrices for a batch of rotation vectors
+         Parameters
+         ----------
+         rot_vecs: torch.tensor Nx3
+             array of N axis-angle vectors
+         Returns
+         -------
+         R: torch.tensor Nx3x3
+             The rotation matrices for the given axis-angle parameters
+         '''
+         batch_size = rot_vecs.shape[0]
+         device, dtype = rot_vecs.device, rot_vecs.dtype
+ 
+         angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
+         rot_dir = rot_vecs / angle
+ 
+         cos = torch.unsqueeze(torch.cos(angle), dim=1)
+         sin = torch.unsqueeze(torch.sin(angle), dim=1)
+ 
+         # Bx1 arrays
+         rx, ry, rz = torch.split(rot_dir, 1, dim=1)
+ 
+         # skew-symmetric cross-product matrix K for each rotation axis
+         zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
+         K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
+             .view((batch_size, 3, 3))
+ 
+         # Rodrigues' formula: R = I + sin(theta) K + (1 - cos(theta)) K^2
+         ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
+         rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
+         return rot_mat
+ 
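+     # Example (illustrative, given skel = SMPLX_Skeleton(device)): a rotation of
+     # pi/2 about the z-axis maps the x-axis onto the y-axis:
+     #   R = skel.batch_rodrigues(torch.tensor([[0.0, 0.0, 1.5708]]))
+     #   R[0] @ torch.tensor([1.0, 0.0, 0.0])  # ~ [0, 1, 0]
+ 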
+     def batch_rigid_transform(self,
+                               rot_mats: Tensor,
+                               joints: Tensor,
+                               parents: Tensor,
+                               dtype=torch.float32
+                               ) -> Tensor:
+         """
+         Applies a batch of rigid transformations to the joints
+ 
+         Parameters
+         ----------
+         rot_mats : torch.tensor BxNx3x3
+             Tensor of rotation matrices
+         joints : torch.tensor BxNx3
+             Locations of joints
+         parents : torch.tensor N
+             The kinematic tree (parent index per joint, root = -1)
+         dtype : torch.dtype, optional:
+             The data type of the created tensors, the default is torch.float32
+ 
+         Returns
+         -------
+         posed_joints : torch.tensor BxNx3
+             The locations of the joints after applying the pose rotations
+         """
+ 
+         joints = torch.unsqueeze(joints, dim=-1)
+ 
+         # joint offsets relative to their parents
+         rel_joints = joints.clone()
+         rel_joints[:, 1:] -= joints[:, parents[1:]]
+ 
+         transforms_mat = self.transform_mat(
+             rot_mats.reshape(-1, 3, 3),
+             rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
+ 
+         # accumulate transforms down the kinematic chain, root first
+         transform_chain = [transforms_mat[:, 0]]
+         for i in range(1, parents.shape[0]):
+             curr_res = torch.matmul(transform_chain[parents[i]],
+                                     transforms_mat[:, i])
+             transform_chain.append(curr_res)
+ 
+         transforms = torch.stack(transform_chain, dim=1)
+ 
+         # The last column of the transformations contains the posed joints
+         posed_joints = transforms[:, :, :3, 3]
+ 
+         return posed_joints  # relative transforms (rel_transforms) are not needed here
+ 
+     def transform_mat(self, R: Tensor, t: Tensor) -> Tensor:
+         ''' Creates a batch of homogeneous transformation matrices
+         Args:
+             - R: Bx3x3 array of a batch of rotation matrices
+             - t: Bx3x1 array of a batch of translation vectors
+         Returns:
+             - T: Bx4x4 transformation matrices
+         '''
+         # No padding left or right, only add an extra row:
+         # T = [[R, t], [0, 0, 0, 1]]
+         return torch.cat([F.pad(R, [0, 0, 0, 1]),
+                           F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
+ 
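+     # Example (illustrative): identity rotation plus translation (1, 2, 3)
+     #   R = torch.eye(3).unsqueeze(0)              # (1, 3, 3)
+     #   t = torch.tensor([[[1.0], [2.0], [3.0]]])  # (1, 3, 1)
+     #   skel.transform_mat(R, t)  # -> [[1,0,0,1],[0,1,0,2],[0,0,1,3],[0,0,0,1]]
+ 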
+     def motion_data_load_process(self, motionfile):
+         if motionfile.split(".")[-1] == "pkl":
+             with open(motionfile, "rb") as f:
+                 pkl_data = pickle.load(f)
+             if "pos" in pkl_data.keys():
+                 local_q_165 = torch.from_numpy(pkl_data["q"]).to(self.device).float()
+                 root_pos = torch.from_numpy(pkl_data["pos"]).to(self.device).float()
+                 # re-center the trajectory on the first frame
+                 root_pos = root_pos[:, :] - root_pos[0, :]
+                 return local_q_165, root_pos
+             else:
+                 smpl_poses = pkl_data["smpl_poses"]
+                 if smpl_poses.shape[0] != 150 and smpl_poses.shape[0] != 300:
+                     # poses stored flattened; reshape to 150 frames
+                     smpl_poses = smpl_poses.reshape(150, -1)
+                 root_pos = pkl_data["smpl_trans"]
+ 
+                 local_q = torch.from_numpy(smpl_poses).to(self.device).float()
+                 root_pos = torch.from_numpy(root_pos).to(self.device).float()
+                 # pad the 156-D SMPL pose to the 165-D SMPL-X layout by inserting
+                 # 9 zeros (3 joints) after the first 66 body parameters
+                 local_q_165 = torch.cat([local_q[:, :66], torch.zeros([local_q.shape[0], 9], device=local_q.device, dtype=torch.float32), local_q[:, 66:]], dim=1).to(self.device).float()
+                 root_pos = root_pos[:, :] - root_pos[0, :]
+                 return local_q_165, root_pos
+         raise ValueError(f"unsupported motion file type: {motionfile}")
+ 
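+     # Usage sketch (illustrative; assumes a FineDance-style .pkl containing either
+     # "q"/"pos" or "smpl_poses"/"smpl_trans"):
+     #   q, pos = skel.motion_data_load_process("012_slice0.pkl")
+     #   q.shape, pos.shape  # -> (T, 165), (T, 3)
+ 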
+     def forward(self, rotations, root_positions):
+         """
+         Perform forward kinematics using the given trajectory and local rotations.
+         Arguments (where N = batch size x sequence length):
+         -- rotations: (N, 156) or (N, 165) axis-angle pose parameters
+         -- root_positions: (N, 3) root joint positions
+         Returns: (N, 55, 3) global joint coordinates
+         """
+         fk_device = rotations.device
+         if rotations.shape[1] == 156:
+             # pad the 156-D pose to the 165-D SMPL-X layout (as in motion_data_load_process)
+             local_q_165 = torch.cat([rotations[:, :66], torch.zeros([rotations.shape[0], 9], device=fk_device, dtype=torch.float32), rotations[:, 66:]], dim=1).to(fk_device).float()
+         elif rotations.shape[1] == 165:
+             local_q_165 = rotations.to(fk_device).float()
+         else:
+             # fail loudly instead of print + sys.exit(0), which signals success
+             raise ValueError(f"unexpected rotations shape: {rotations.shape}")
+ 
+         root_pos = root_positions.to(fk_device).float()
+         assert local_q_165.shape[1] == 165
+ 
+         B, C = local_q_165.shape
+         rot_mats = self.batch_rodrigues(local_q_165.view(-1, 3)).view(
+             [B, -1, 3, 3])
+ 
+         # slice the pre-batched rest-pose joints, or re-broadcast if the batch grew
+         if self.J.shape[0] >= B:
+             J_temp = self.J[:B, :, :]
+         else:
+             J_temp = self.J[:1, :, :].repeat(B, 1, 1)
+             print("warning: self.J batch dim is smaller than batch_size x seq_len; re-broadcasting")
+ 
+         parents = torch.Tensor(self.parents).long()
+         J_transformed = self.batch_rigid_transform(rot_mats, J_temp, parents, dtype=torch.float32)
+         J_transformed += root_pos.unsqueeze(dim=1)
+ 
+         return J_transformed
+ 
+ 
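+ # Shape summary (illustrative): for T frames of pose parameters,
+ #   SMPLX_Skeleton(device, batch=T).forward(q, pos)  # q: (T, 156) or (T, 165), pos: (T, 3) -> (T, 55, 3)
+ 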
+ if __name__ == "__main__":
+     # quick smoke test; the motion file path below is machine-specific
+     device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+ 
+     smplx_fk = SMPLX_Skeleton(device=device, batch=150)
+     motion_file = "/home/data/lrh/datasets/fine_dance/magicsmpl/sliced/test/dances/012_slice0.pkl"
+     # music_file = "/home/data/lrh/datasets/fine_dance/magicsmpl/sliced/test/wavs/012_slice0.wav"
+     local_q_165, root_pos = smplx_fk.motion_data_load_process(motion_file)
+     print("local_q_165.shape", local_q_165.shape)  # (150, 165)
+     print("root_pos.shape", root_pos.shape)        # (150, 3)
+ 
+     joints = smplx_fk.forward(local_q_165, root_pos).detach().cpu().numpy()
+     print("joints.shape", joints.shape)  # (150, 55, 3)
+ 
+     # skeleton_render(
+     #     joints,
+     #     epoch=f"e{1}_b{1}",
+     #     out="./output/temp",
+     #     name=music_file,
+     #     render=True,
+     #     stitch=False,
+     #     sound=True,
+     #     smpl_mode="smplx"
+     # )
+ 