Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +187 -0
- LICENSE +52 -0
- README.md +164 -3
- app.py +84 -0
- args.py +119 -0
- data/code/pre_motion.py +72 -0
- data/code/pre_music.py +90 -0
- data/code/slice_music_motion.py +41 -0
- dataset/FineDance_dataset.py +180 -0
- dataset/__init__.py +0 -0
- dataset/preprocess.py +93 -0
- dataset/quaternion.py +71 -0
- dataset/scaler.py +83 -0
- environment.yaml +343 -0
- environment_macos.yaml +64 -0
- generate_all.py +57 -0
- generate_dance.py +240 -0
- model/adan.py +123 -0
- model/diffusion.py +741 -0
- model/model.py +444 -0
- model/rotary_embedding_torch.py +132 -0
- model/utils.py +99 -0
- render.py +395 -0
- smplx_neu_J_1.npy +3 -0
- teaser/teaser.png +3 -0
- test.py +187 -0
- train_seq.py +318 -0
- vis.py +687 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
teaser/teaser.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# Project-specific
|
| 7 |
+
experiments
|
| 8 |
+
data/finedance
|
| 9 |
+
generated
|
| 10 |
+
eval
|
| 11 |
+
wandb
|
| 12 |
+
assets/checkpoints/
|
| 13 |
+
assets/smpl_model/
|
| 14 |
+
assets/ffmpeg-6.0-amd64-static/
|
| 15 |
+
assets/NORMAL_new.obj
|
| 16 |
+
|
| 17 |
+
# User-generated output and uploaded music
|
| 18 |
+
output/
|
| 19 |
+
custom_music/
|
| 20 |
+
*.mp4
|
| 21 |
+
*.wav
|
| 22 |
+
*.mp3
|
| 23 |
+
*.flac
|
| 24 |
+
*.ogg
|
| 25 |
+
*.m4a
|
| 26 |
+
|
| 27 |
+
# macOS
|
| 28 |
+
.DS_Store
|
| 29 |
+
|
| 30 |
+
# VSCode
|
| 31 |
+
.vscode/
|
| 32 |
+
|
| 33 |
+
# C extensions
|
| 34 |
+
*.so
|
| 35 |
+
|
| 36 |
+
# Distribution / packaging
|
| 37 |
+
.Python
|
| 38 |
+
build/
|
| 39 |
+
develop-eggs/
|
| 40 |
+
dist/
|
| 41 |
+
downloads/
|
| 42 |
+
eggs/
|
| 43 |
+
.eggs/
|
| 44 |
+
lib/
|
| 45 |
+
lib64/
|
| 46 |
+
parts/
|
| 47 |
+
sdist/
|
| 48 |
+
var/
|
| 49 |
+
wheels/
|
| 50 |
+
share/python-wheels/
|
| 51 |
+
*.egg-info/
|
| 52 |
+
.installed.cfg
|
| 53 |
+
*.egg
|
| 54 |
+
MANIFEST
|
| 55 |
+
|
| 56 |
+
# PyInstaller
|
| 57 |
+
# Usually these files are written by a python script from a template
|
| 58 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 59 |
+
*.manifest
|
| 60 |
+
*.spec
|
| 61 |
+
|
| 62 |
+
# Installer logs
|
| 63 |
+
pip-log.txt
|
| 64 |
+
pip-delete-this-directory.txt
|
| 65 |
+
|
| 66 |
+
# Unit test / coverage reports
|
| 67 |
+
htmlcov/
|
| 68 |
+
.tox/
|
| 69 |
+
.nox/
|
| 70 |
+
.coverage
|
| 71 |
+
.coverage.*
|
| 72 |
+
.cache
|
| 73 |
+
nosetests.xml
|
| 74 |
+
coverage.xml
|
| 75 |
+
*.cover
|
| 76 |
+
*.py,cover
|
| 77 |
+
.hypothesis/
|
| 78 |
+
.pytest_cache/
|
| 79 |
+
cover/
|
| 80 |
+
|
| 81 |
+
# Translations
|
| 82 |
+
*.mo
|
| 83 |
+
*.pot
|
| 84 |
+
|
| 85 |
+
# Django stuff:
|
| 86 |
+
*.log
|
| 87 |
+
local_settings.py
|
| 88 |
+
db.sqlite3
|
| 89 |
+
db.sqlite3-journal
|
| 90 |
+
|
| 91 |
+
# Flask stuff:
|
| 92 |
+
instance/
|
| 93 |
+
.webassets-cache
|
| 94 |
+
|
| 95 |
+
# Scrapy stuff:
|
| 96 |
+
.scrapy
|
| 97 |
+
|
| 98 |
+
# Sphinx documentation
|
| 99 |
+
docs/_build/
|
| 100 |
+
|
| 101 |
+
# PyBuilder
|
| 102 |
+
.pybuilder/
|
| 103 |
+
target/
|
| 104 |
+
|
| 105 |
+
# Jupyter Notebook
|
| 106 |
+
.ipynb_checkpoints
|
| 107 |
+
|
| 108 |
+
# IPython
|
| 109 |
+
profile_default/
|
| 110 |
+
ipython_config.py
|
| 111 |
+
|
| 112 |
+
# pyenv
|
| 113 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 114 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 115 |
+
# .python-version
|
| 116 |
+
|
| 117 |
+
# pipenv
|
| 118 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 119 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 120 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 121 |
+
# install all needed dependencies.
|
| 122 |
+
#Pipfile.lock
|
| 123 |
+
|
| 124 |
+
# poetry
|
| 125 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 126 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 127 |
+
# commonly ignored for libraries.
|
| 128 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 129 |
+
#poetry.lock
|
| 130 |
+
|
| 131 |
+
# pdm
|
| 132 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 133 |
+
#pdm.lock
|
| 134 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 135 |
+
# in version control.
|
| 136 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 137 |
+
.pdm.toml
|
| 138 |
+
|
| 139 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 140 |
+
__pypackages__/
|
| 141 |
+
|
| 142 |
+
# Celery stuff
|
| 143 |
+
celerybeat-schedule
|
| 144 |
+
celerybeat.pid
|
| 145 |
+
|
| 146 |
+
# SageMath parsed files
|
| 147 |
+
*.sage.py
|
| 148 |
+
|
| 149 |
+
# Environments
|
| 150 |
+
.env
|
| 151 |
+
.venv
|
| 152 |
+
env/
|
| 153 |
+
venv/
|
| 154 |
+
ENV/
|
| 155 |
+
env.bak/
|
| 156 |
+
venv.bak/
|
| 157 |
+
|
| 158 |
+
# Spyder project settings
|
| 159 |
+
.spyderproject
|
| 160 |
+
.spyproject
|
| 161 |
+
|
| 162 |
+
# Rope project settings
|
| 163 |
+
.ropeproject
|
| 164 |
+
|
| 165 |
+
# mkdocs documentation
|
| 166 |
+
/site
|
| 167 |
+
|
| 168 |
+
# mypy
|
| 169 |
+
.mypy_cache/
|
| 170 |
+
.dmypy.json
|
| 171 |
+
dmypy.json
|
| 172 |
+
|
| 173 |
+
# Pyre type checker
|
| 174 |
+
.pyre/
|
| 175 |
+
|
| 176 |
+
# pytype static type analyzer
|
| 177 |
+
.pytype/
|
| 178 |
+
|
| 179 |
+
# Cython debug symbols
|
| 180 |
+
cython_debug/
|
| 181 |
+
|
| 182 |
+
# PyCharm
|
| 183 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 184 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 185 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 186 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 187 |
+
#.idea/
|
LICENSE
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
License
|
| 2 |
+
Software Copyright License for non-commercial scientific research purposes
|
| 3 |
+
Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use FineDance data, model and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
|
| 4 |
+
|
| 5 |
+
Ownership / Licensees
|
| 6 |
+
The Software and the associated materials has been developed at the
|
| 7 |
+
|
| 8 |
+
Professor Xiu Li, Tsinghua University.
|
| 9 |
+
|
| 10 |
+
Any copyright or patent right is owned by and proprietary material of the
|
| 11 |
+
|
| 12 |
+
Professor Xiu Li, Tsinghua University.
|
| 13 |
+
|
| 14 |
+
hereinafter the “Licensor”.
|
| 15 |
+
|
| 16 |
+
License Grant
|
| 17 |
+
Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
|
| 18 |
+
|
| 19 |
+
To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization;
|
| 20 |
+
To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
|
| 21 |
+
Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Xiu Li’s prior written permission.
|
| 22 |
+
|
| 23 |
+
The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
|
| 24 |
+
|
| 25 |
+
No Distribution
|
| 26 |
+
The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
|
| 27 |
+
|
| 28 |
+
Disclaimer of Representations and Warranties
|
| 29 |
+
You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.
|
| 30 |
+
|
| 31 |
+
Limitation of Liability
|
| 32 |
+
The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.
|
| 33 |
+
|
| 34 |
+
No Maintenance Services
|
| 35 |
+
You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.
|
| 36 |
+
|
| 37 |
+
Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
|
| 38 |
+
|
| 39 |
+
Publications using the Data & Software
|
| 40 |
+
You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.
|
| 41 |
+
|
| 42 |
+
Citation:
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@InProceedings{Li_2023_ICCV,
|
| 46 |
+
author = {Li, Ronghui and Zhao, Junfan and Zhang, Yachao and Su, Mingyang and Ren, Zeping and Zhang, Han and Tang, Yansong and Li, Xiu},
|
| 47 |
+
title = {FineDance: A Fine-grained Choreography Dataset for 3D Full Body Dance Generation},
|
| 48 |
+
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
| 49 |
+
month = {October},
|
| 50 |
+
year = {2023},
|
| 51 |
+
pages = {10234-10243}
|
| 52 |
+
}
|
README.md
CHANGED
|
@@ -1,3 +1,164 @@
|
|
| 1 |
-
--
|
| 2 |
-
|
| 3 |
-
--
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [FineDance: A Fine-grained Choreography Dataset for 3D Full Body Dance Generation (ICCV 2023)](https://github.com/li-ronghui/FineDance)
|
| 2 |
+
|
| 3 |
+
[[Project Page](https://li-ronghui.github.io/finedance)] | [[Preprint](https://arxiv.org/abs/2212.03741)] | [[pdf](https://arxiv.org/pdf/2212.03741.pdf)] | [[video](https://li-ronghui.github.io/finedance)]
|
| 4 |
+
|
| 5 |
+
<img src="teaser/teaser.png">
|
| 6 |
+
|
| 7 |
+
## Quick Start
|
| 8 |
+
|
| 9 |
+
### Prerequisites
|
| 10 |
+
|
| 11 |
+
Install the conda environment and activate it:
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
conda env create -f environment.yaml
|
| 15 |
+
conda activate FineNet
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
Download the pretrained checkpoints and asset files from [Google Drive](https://drive.google.com/file/d/1ENoeUn-X-3Vw2Gon-voVLlndy3hZXdWD/view?usp=drive_link).
|
| 19 |
+
|
| 20 |
+
### Web UI (Recommended)
|
| 21 |
+
|
| 22 |
+
Launch the Gradio web interface:
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
python app.py
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Open `http://127.0.0.1:7861` in your browser. Upload a music file and click "Generate Dance" to produce a video.
|
| 29 |
+
|
| 30 |
+
### Command Line
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
python generate_dance.py /path/to/music.mp3
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Output will be saved to `output/<songname>_dance.mp4`.
|
| 37 |
+
|
| 38 |
+
To specify a custom output path:
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
python generate_dance.py /path/to/music.mp3 --output my_dance.mp4
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Supported audio formats: any format ffmpeg can read (`.mp3`, `.wav`, `.m4a`, `.flac`, `.ogg`, etc.).
|
| 45 |
+
|
| 46 |
+
### Output Details
|
| 47 |
+
|
| 48 |
+
- Resolution: 1200x1200
|
| 49 |
+
- Frame rate: 30 fps
|
| 50 |
+
- Duration: ~30 seconds
|
| 51 |
+
- Background: black
|
| 52 |
+
- Body model: SMPLX (full body with hands)
|
| 53 |
+
|
| 54 |
+
## How It Works
|
| 55 |
+
|
| 56 |
+
1. **Audio conversion** - Converts input to WAV if needed
|
| 57 |
+
2. **Feature extraction** - Slices audio into 4-second windows with 2-second stride, then extracts 35-dim features per slice (onset envelope, 20 MFCC, 12 chroma, peak onehot, beat onehot) using librosa
|
| 58 |
+
3. **Dance generation** - Feeds audio features into a pretrained diffusion model (`assets/checkpoints/train-2000.pt`) which generates SMPLX body motion (319-dim: 4 contact + 3 translation + 52 joints x 6 rotation)
|
| 59 |
+
4. **Rendering** - Converts generated motion to SMPLX meshes and renders 900 frames at 30fps using pyrender
|
| 60 |
+
5. **Final output** - Merges rendered video with original audio via ffmpeg
|
| 61 |
+
|
| 62 |
+
## FineDance Dataset
|
| 63 |
+
|
| 64 |
+
The dataset (7.7 hours) can be downloaded from [Google Drive](https://drive.google.com/file/d/1zQvWG9I0H4U3Zrm8d_QD_ehenZvqfQfS/view?usp=sharing) or [Baidu Cloud](https://pan.baidu.com/s/1gynUC7pMdpsE31wAwq177w?pwd=o9pw).
|
| 65 |
+
|
| 66 |
+
Put the downloaded data into `./data`. The data directory contains:
|
| 67 |
+
|
| 68 |
+
- **label_json** - Song name, coarse style, and fine-grained genre
|
| 69 |
+
- **motion** - [SMPLH](https://smpl-x.is.tue.mpg.de/) format motion data
|
| 70 |
+
- **music_wav** - Music data in WAV format
|
| 71 |
+
- **music_npy** - Music features extracted by [librosa](https://github.com/librosa/librosa) following [AIST++](https://github.com/google/aistplusplus_api/tree/main)
|
| 72 |
+
|
| 73 |
+
Reading a motion file:
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
import numpy as np
|
| 77 |
+
data = np.load("motion/001.npy")
|
| 78 |
+
T, C = data.shape # T is the number of frames
|
| 79 |
+
smpl_poses = data[:, 3:]
|
| 80 |
+
smpl_trans = data[:, :3]
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### Dataset Split
|
| 84 |
+
|
| 85 |
+
The dataset is split into train, val, and test sets in two ways:
|
| 86 |
+
|
| 87 |
+
1. **FineDance@Genre** - Test set includes a broader range of dance genres; the same dancer may appear across splits but with different motions. Recommended for dance generation.
|
| 88 |
+
2. **FineDance@Dancer** - Splits are divided by dancer; the same dancer won't appear in different sets, but the test set contains fewer genres.
|
| 89 |
+
|
| 90 |
+
## Training
|
| 91 |
+
|
| 92 |
+
Only needed if you want to train from scratch. The pretrained checkpoint is already provided.
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
# Data preprocessing
|
| 96 |
+
python data/code/pre_motion.py
|
| 97 |
+
|
| 98 |
+
# Train
|
| 99 |
+
accelerate launch train_seq.py --batch_size 32 --epochs 200
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
Key flags:
|
| 103 |
+
- `--batch_size` - Default is 400, reduce to 32 or lower for Mac MPS (limited to ~30GB)
|
| 104 |
+
- `--epochs` - Default is 2000
|
| 105 |
+
- `--checkpoint` - Resume from a saved checkpoint
|
| 106 |
+
|
| 107 |
+
## Advanced Usage
|
| 108 |
+
|
| 109 |
+
### Generate on the test set
|
| 110 |
+
|
| 111 |
+
```bash
|
| 112 |
+
python data/code/slice_music_motion.py
|
| 113 |
+
python generate_all.py --motion_save_dir generated/finedance_seq_120_dancer --save_motions
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
### Render a pre-generated motion file
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
python render.py --modir eval/motions --mode smplx
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Project Structure
|
| 123 |
+
|
| 124 |
+
```
|
| 125 |
+
FineDance/
|
| 126 |
+
├── app.py # Gradio web UI
|
| 127 |
+
├── generate_dance.py # One-command dance generation (CLI)
|
| 128 |
+
├── train_seq.py # Training script
|
| 129 |
+
├── test.py # Original test/inference script
|
| 130 |
+
├── render.py # Video rendering (SMPLX mesh to MP4)
|
| 131 |
+
├── args.py # CLI argument definitions
|
| 132 |
+
├── vis.py # Skeleton/FK utilities
|
| 133 |
+
├── assets/
|
| 134 |
+
│ ├── checkpoints/
|
| 135 |
+
│ │ └── train-2000.pt # Pretrained model (2000 epochs)
|
| 136 |
+
│ └── smpl_model/
|
| 137 |
+
│ └── smplx/
|
| 138 |
+
│ └── SMPLX_NEUTRAL.npz # SMPLX body model
|
| 139 |
+
├── model/
|
| 140 |
+
│ ├── model.py # SeqModel (transformer decoder)
|
| 141 |
+
│ └── diffusion.py # Gaussian diffusion (training + sampling)
|
| 142 |
+
├── dataset/
|
| 143 |
+
│ └── FineDance_dataset.py # Dataset loader
|
| 144 |
+
└── data/
|
| 145 |
+
└── finedance/ # Training data (music + motion pairs)
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## Acknowledgments
|
| 149 |
+
|
| 150 |
+
We would like to express our sincere gratitude to Dr [Yan Zhang](https://yz-cnsdqz.github.io/) and [Yulun Zhang](https://yulunzhang.com/) for their invaluable guidance and insights during the course of our research.
|
| 151 |
+
|
| 152 |
+
This code is based on: [EDGE](https://github.com/Stanford-TML/EDGE/tree/main), [MDM](https://github.com/Stanford-TML/EDGE/tree/main), [Adan](https://github.com/lucidrains/Adan-pytorch), [Diffusion](https://github.com/lucidrains/denoising-diffusion-pytorch), [SMPLX](https://smpl-x.is.tue.mpg.de/).
|
| 153 |
+
|
| 154 |
+
## Citation
|
| 155 |
+
|
| 156 |
+
```
|
| 157 |
+
@inproceedings{li2023finedance,
|
| 158 |
+
title={FineDance: A Fine-grained Choreography Dataset for 3D Full Body Dance Generation},
|
| 159 |
+
author={Li, Ronghui and Zhao, Junfan and Zhang, Yachao and Su, Mingyang and Ren, Zeping and Zhang, Han and Tang, Yansong and Li, Xiu},
|
| 160 |
+
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
|
| 161 |
+
pages={10234--10243},
|
| 162 |
+
year={2023}
|
| 163 |
+
}
|
| 164 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI for FineDance — generate dance videos from music.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
conda activate FineNet
|
| 6 |
+
python app.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import tempfile
|
| 11 |
+
|
| 12 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
|
| 16 |
+
# Monkey-patch gradio_client bug: additionalProperties can be a bool,
|
| 17 |
+
# but the code assumes it's always a dict.
|
| 18 |
+
import gradio_client.utils as _gc_utils
|
| 19 |
+
|
| 20 |
+
_orig_json_schema_to_python_type = _gc_utils._json_schema_to_python_type
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _patched_json_schema_to_python_type(schema, defs=None):
|
| 24 |
+
if isinstance(schema, bool):
|
| 25 |
+
return "Any"
|
| 26 |
+
return _orig_json_schema_to_python_type(schema, defs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
|
| 30 |
+
from generate_dance import load_model, generate, _setup_render_args
|
| 31 |
+
from render import MovieMaker
|
| 32 |
+
|
| 33 |
+
# Preload model once at startup
|
| 34 |
+
print("Loading model...")
|
| 35 |
+
MODEL = load_model()
|
| 36 |
+
print("Model loaded. Initializing renderer...")
|
| 37 |
+
|
| 38 |
+
# Create MovieMaker on the main thread so pyglet's signal handler works.
|
| 39 |
+
_setup_render_args()
|
| 40 |
+
VISUALIZER = MovieMaker(save_path=".")
|
| 41 |
+
print("Starting UI...")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def run(audio_path):
|
| 45 |
+
if audio_path is None:
|
| 46 |
+
raise gr.Error("Please upload a music file.")
|
| 47 |
+
|
| 48 |
+
logs = []
|
| 49 |
+
|
| 50 |
+
def log_fn(msg):
|
| 51 |
+
logs.append(msg)
|
| 52 |
+
print(msg)
|
| 53 |
+
|
| 54 |
+
songname = os.path.splitext(os.path.basename(audio_path))[0]
|
| 55 |
+
output_path = os.path.join(tempfile.gettempdir(), f"{songname}_dance.mp4")
|
| 56 |
+
|
| 57 |
+
generate(audio_path, output_path, model=MODEL, visualizer=VISUALIZER, log_fn=log_fn)
|
| 58 |
+
|
| 59 |
+
return output_path, "\n".join(logs)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
with gr.Blocks(title="FineDance") as demo:
|
| 63 |
+
gr.Markdown("# FineDance\nUpload a music file to generate a 3D dance video.")
|
| 64 |
+
|
| 65 |
+
with gr.Row():
|
| 66 |
+
with gr.Column():
|
| 67 |
+
audio_input = gr.Audio(
|
| 68 |
+
label="Upload Music",
|
| 69 |
+
type="filepath",
|
| 70 |
+
)
|
| 71 |
+
generate_btn = gr.Button("Generate Dance", variant="primary")
|
| 72 |
+
|
| 73 |
+
with gr.Column():
|
| 74 |
+
video_output = gr.Video(label="Generated Dance")
|
| 75 |
+
status_output = gr.Textbox(label="Status", lines=6, interactive=False)
|
| 76 |
+
|
| 77 |
+
generate_btn.click(
|
| 78 |
+
fn=run,
|
| 79 |
+
inputs=[audio_input],
|
| 80 |
+
outputs=[video_output, status_output],
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
demo.launch(server_name="127.0.0.1")
|
args.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import yaml
|
| 3 |
+
|
| 4 |
+
def FineDance_parse_train_opt():
|
| 5 |
+
parser = argparse.ArgumentParser()
|
| 6 |
+
parser.add_argument("--project", default="experiments/finedance_seq_120_genre/train", help="project/name")
|
| 7 |
+
parser.add_argument("--exp_name", default="finedance_seq_120_genre", help="save to project/name")
|
| 8 |
+
parser.add_argument("--feature_type", type=str, default="baseline")
|
| 9 |
+
parser.add_argument("--datasplit", type=str, default="cross_genre", choices=["cross_genre", "cross_dancer"])
|
| 10 |
+
parser.add_argument(
|
| 11 |
+
"--render_dir", type=str, default="experiments/finedance_seq_120_genre/renders", help="Sample render path"
|
| 12 |
+
)
|
| 13 |
+
parser.add_argument(
|
| 14 |
+
"--full_seq_len", type=int, default=120, help="full_seq_len"
|
| 15 |
+
)
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--windows", type=int, default=10, help="windows"
|
| 18 |
+
)
|
| 19 |
+
parser.add_argument(
|
| 20 |
+
"--mix", action="store_true", help="Saves the motions for evaluation"
|
| 21 |
+
)
|
| 22 |
+
# parser.add_argument("--feature_type", type=str, default="jukebox")
|
| 23 |
+
parser.add_argument(
|
| 24 |
+
"--wandb_pj_name", type=str, default="finedance_seq", help="project name"
|
| 25 |
+
)
|
| 26 |
+
parser.add_argument("--batch_size", type=int, default=400, help="batch size") # default=64
|
| 27 |
+
parser.add_argument("--epochs", type=int, default=2000)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--save_interval",
|
| 30 |
+
type=int,
|
| 31 |
+
default=10, # default=100,
|
| 32 |
+
help='Log model after every "save_period" epoch',
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument("--ema_interval", type=int, default=1, help="ema every x steps")
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--checkpoint", type=str, default="", help="trained checkpoint path (optional)"
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--do_normalize",
|
| 40 |
+
action="store_true",
|
| 41 |
+
help="normalize",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--nfeats", type=int, default=319, help="nfeats"
|
| 45 |
+
)
|
| 46 |
+
opt = parser.parse_args()
|
| 47 |
+
return opt
|
| 48 |
+
|
| 49 |
+
def FineDance_parse_test_opt():
|
| 50 |
+
parser = argparse.ArgumentParser()
|
| 51 |
+
parser.add_argument("--feature_type", type=str, default="baseline")
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--full_seq_len", type=int, default=120, help="full_seq_len"
|
| 54 |
+
)
|
| 55 |
+
parser.add_argument("--datasplit", type=str, default="cross_genre", choices=["cross_genre", "cross_dancer"])
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--windows", type=int, default=10, help="windows"
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument("--out_length", type=float, default=30, help="max. length of output, in seconds")
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--render_dir", type=str, default="FineDance_test_renders/", help="Sample render path"
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--checkpoint", type=str, default="assets/checkpoints/train-2000.pt", help="checkpoint"
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--nfeats", type=int, default=319, help="nfeats"
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--music_dir",
|
| 71 |
+
type=str,
|
| 72 |
+
default="data/finedance/music_wav",
|
| 73 |
+
help="folder containing input music",
|
| 74 |
+
)
|
| 75 |
+
parser.add_argument(
|
| 76 |
+
"--save_motions", action="store_true", help="Saves the motions for evaluation"
|
| 77 |
+
)
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--motion_save_dir",
|
| 80 |
+
type=str,
|
| 81 |
+
default="eval/motions",
|
| 82 |
+
help="Where to save the motions",
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--cache_features",
|
| 86 |
+
action="store_true",
|
| 87 |
+
help="Save the jukebox features for later reuse",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--do_normalize",
|
| 91 |
+
action="store_true",
|
| 92 |
+
help="normalize",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--no_render",
|
| 96 |
+
action="store_true",
|
| 97 |
+
help="Don't render the video",
|
| 98 |
+
)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--use_cached_features",
|
| 101 |
+
action="store_true",
|
| 102 |
+
help="Use precomputed features instead of music folder",
|
| 103 |
+
)
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--feature_cache_dir",
|
| 106 |
+
type=str,
|
| 107 |
+
default="cached_features/",
|
| 108 |
+
help="Where to save/load the features",
|
| 109 |
+
)
|
| 110 |
+
opt = parser.parse_args()
|
| 111 |
+
return opt
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def save_arguments_to_yaml(args, file_path):
|
| 115 |
+
arg_dict = vars(args) # 将Namespace对象转换为字典
|
| 116 |
+
yaml_str = yaml.dump(arg_dict, default_flow_style=False)
|
| 117 |
+
|
| 118 |
+
with open(file_path, 'w') as file:
|
| 119 |
+
file.write(yaml_str)
|
data/code/pre_motion.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import smplx, pickle
|
| 5 |
+
import torch
|
| 6 |
+
import sys
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import glob
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
sys.path.append(os.getcwd())
|
| 12 |
+
from dataset.quaternion import ax_to_6v, ax_from_6v
|
| 13 |
+
from dataset.preprocess import Normalizer, vectorize_many
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def motion_feats_extract(inputs_dir, outputs_dir):
|
| 17 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 18 |
+
print("extracting")
|
| 19 |
+
raw_fps = 30
|
| 20 |
+
data_fps = 30
|
| 21 |
+
data_fps <= raw_fps
|
| 22 |
+
if not os.path.exists(outputs_dir):
|
| 23 |
+
os.makedirs(outputs_dir)
|
| 24 |
+
# All motion is retargeted to this standard model.
|
| 25 |
+
smplx_model = smplx.SMPLX(model_path='assets/smpl_model/smplx', ext='npz', gender='neutral',
|
| 26 |
+
num_betas=10, flat_hand_mean=True, num_expression_coeffs=10, use_pca=False).eval().to(device)
|
| 27 |
+
|
| 28 |
+
motions = sorted(glob.glob(os.path.join(inputs_dir, "*.npy")))
|
| 29 |
+
for motion in tqdm(motions):
|
| 30 |
+
name = os.path.splitext(os.path.basename(motion))[0].split(".")[0]
|
| 31 |
+
print("name is", name)
|
| 32 |
+
data = np.load(motion, allow_pickle=True)
|
| 33 |
+
print(data.shape)
|
| 34 |
+
pos = data[:,:3] # length, c
|
| 35 |
+
q = data[:,3:]
|
| 36 |
+
root_pos = torch.Tensor(pos).to(device) # T, 3
|
| 37 |
+
length = root_pos.shape[0]
|
| 38 |
+
local_q_rot6d = torch.Tensor(q).to(device) # T, 312
|
| 39 |
+
print("local_q_rot6d", local_q_rot6d.shape)
|
| 40 |
+
local_q = local_q_rot6d.reshape(length, 52, 6).clone()
|
| 41 |
+
local_q = ax_from_6v(local_q).view(length, 156) # T, 156
|
| 42 |
+
|
| 43 |
+
smplx_output = smplx_model(
|
| 44 |
+
betas = torch.zeros([root_pos.shape[0], 10], device=device, dtype=torch.float32),
|
| 45 |
+
transl = root_pos, # global translation
|
| 46 |
+
global_orient = local_q[:, :3],
|
| 47 |
+
body_pose = local_q[:, 3:66], # 21
|
| 48 |
+
jaw_pose = torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32), # 1
|
| 49 |
+
leye_pose = torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32), # 1
|
| 50 |
+
reye_pose= torch.zeros([root_pos.shape[0], 3], device=device, dtype=torch.float32), # 1
|
| 51 |
+
left_hand_pose = local_q[:, 66:66+45], # 15
|
| 52 |
+
right_hand_pose = local_q[:, 66+45:], # 15
|
| 53 |
+
expression = torch.zeros([root_pos.shape[0], 10], device=device, dtype=torch.float32),
|
| 54 |
+
return_verts = False
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
positions = smplx_output.joints.view(length, -1, 3) # bxt, j, 3
|
| 59 |
+
feet = positions[:, (7, 8, 10, 11)] # # 150, 4, 3
|
| 60 |
+
feetv = torch.zeros(feet.shape[:2], device=device) # 150, 4
|
| 61 |
+
feetv[:-1] = (feet[1:] - feet[:-1]).norm(dim=-1)
|
| 62 |
+
contacts = (feetv < 0.01).to(local_q) # cast to right dtype # b, 150, 4
|
| 63 |
+
|
| 64 |
+
mofea319 = torch.cat([contacts, root_pos, local_q_rot6d], dim=1)
|
| 65 |
+
assert mofea319.shape[1] == 319
|
| 66 |
+
mofea319 = mofea319.detach().cpu().numpy()
|
| 67 |
+
np.save(os.path.join(outputs_dir, name+'.npy'), mofea319)
|
| 68 |
+
return
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
motion_feats_extract("data/finedance/motion", "data/finedance/motion_fea319")
|
data/code/pre_music.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
import wave
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import librosa as lr
|
| 7 |
+
|
| 8 |
+
FPS = 30 #* 5
|
| 9 |
+
HOP_LENGTH = 512
|
| 10 |
+
SR = FPS * HOP_LENGTH
|
| 11 |
+
EPS = 1e-6
|
| 12 |
+
|
| 13 |
+
# HOP_LENGTH = 160
|
| 14 |
+
# SR = 16000
|
| 15 |
+
|
| 16 |
+
audio_dir = 'data/finedance/music_wav'
|
| 17 |
+
# audio_dir = '/home/human/datasets/aist_plusplus_final/music'
|
| 18 |
+
# audio_dir = "/home/human/datasets/data/Clip/music_clip_rhythm"
|
| 19 |
+
|
| 20 |
+
target_dir_ori = "data/finedance/music_wav_test"
|
| 21 |
+
os.makedirs(target_dir_ori, exist_ok=True)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# AIST++
|
| 25 |
+
def _get_tempo(audio_name):
|
| 26 |
+
"""Get tempo (BPM) for a music by parsing music name."""
|
| 27 |
+
# a lot of stuff, only take the 5th element
|
| 28 |
+
audio_name = audio_name.split("_")[4]
|
| 29 |
+
assert len(audio_name) == 4
|
| 30 |
+
if audio_name[0:3] in [
|
| 31 |
+
"mBR",
|
| 32 |
+
"mPO",
|
| 33 |
+
"mLO",
|
| 34 |
+
"mMH",
|
| 35 |
+
"mLH",
|
| 36 |
+
"mWA",
|
| 37 |
+
"mKR",
|
| 38 |
+
"mJS",
|
| 39 |
+
"mJB",
|
| 40 |
+
]:
|
| 41 |
+
return int(audio_name[3]) * 10 + 80
|
| 42 |
+
elif audio_name[0:3] == "mHO":
|
| 43 |
+
return int(audio_name[3]) * 5 + 110
|
| 44 |
+
else:
|
| 45 |
+
assert False, audio_name
|
| 46 |
+
|
| 47 |
+
for file in tqdm(os.listdir(audio_dir)):
|
| 48 |
+
audio_name = file[:-4]
|
| 49 |
+
|
| 50 |
+
save_path = os.path.join(target_dir_ori, f"{audio_name}.npy") ##存特征路径
|
| 51 |
+
music_file = os.path.join(audio_dir, file)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
data, _ = librosa.load(music_file, sr=SR)
|
| 55 |
+
|
| 56 |
+
envelope = librosa.onset.onset_strength(y=data, sr=SR) # (seq_len,)
|
| 57 |
+
mfcc = librosa.feature.mfcc(y=data, sr=SR, n_mfcc=20).T # (seq_len, 20)
|
| 58 |
+
chroma = librosa.feature.chroma_cens(
|
| 59 |
+
y=data, sr=SR, hop_length=HOP_LENGTH, n_chroma=12
|
| 60 |
+
).T # (seq_len, 12)
|
| 61 |
+
|
| 62 |
+
peak_idxs = librosa.onset.onset_detect(
|
| 63 |
+
onset_envelope=envelope.flatten(), sr=SR, hop_length=HOP_LENGTH
|
| 64 |
+
)
|
| 65 |
+
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
|
| 66 |
+
peak_onehot[peak_idxs] = 1.0 # (seq_len,)
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
start_bpm = _get_tempo(audio_name)
|
| 70 |
+
except:
|
| 71 |
+
# determine manually
|
| 72 |
+
start_bpm = lr.beat.tempo(y=lr.load(music_file)[0])[0]
|
| 73 |
+
|
| 74 |
+
tempo, beat_idxs = librosa.beat.beat_track(
|
| 75 |
+
onset_envelope=envelope,
|
| 76 |
+
sr=SR,
|
| 77 |
+
hop_length=HOP_LENGTH,
|
| 78 |
+
start_bpm=start_bpm,
|
| 79 |
+
tightness=100,
|
| 80 |
+
)
|
| 81 |
+
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
|
| 82 |
+
beat_onehot[beat_idxs] = 1.0 # (seq_len,)
|
| 83 |
+
|
| 84 |
+
audio_feature = np.concatenate(
|
| 85 |
+
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
|
| 86 |
+
axis=-1,
|
| 87 |
+
)
|
| 88 |
+
np.save(save_path, audio_feature)
|
| 89 |
+
|
| 90 |
+
|
data/code/slice_music_motion.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
music_dir = "data/finedance/music_npy"
|
| 6 |
+
motion_dir = "data/finedance/motion_fea319"
|
| 7 |
+
|
| 8 |
+
music_out = "data/finedance/div_by_time/music_npy_"
|
| 9 |
+
motion_out = "data/finedance/div_by_time/motion_fea319_"
|
| 10 |
+
|
| 11 |
+
timelen = 120
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
music_out = music_out + str(timelen)
|
| 15 |
+
motion_out = motion_out + str(timelen)
|
| 16 |
+
if not os.path.exists(music_out):
|
| 17 |
+
os.makedirs(music_out)
|
| 18 |
+
if not os.path.exists(motion_out):
|
| 19 |
+
os.makedirs(motion_out)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
for file in os.listdir(motion_dir):
|
| 23 |
+
if file[-3:] != 'npy':
|
| 24 |
+
print(file[-3:])
|
| 25 |
+
continue
|
| 26 |
+
name = file.split(".")[0]
|
| 27 |
+
music_fea = np.load(os.path.join(music_dir, file))
|
| 28 |
+
motion_fea = np.load(os.path.join(motion_dir, file))
|
| 29 |
+
max_length = min(music_fea.shape[0], motion_fea.shape[0])
|
| 30 |
+
|
| 31 |
+
iters = (max_length//timelen)
|
| 32 |
+
max_length = iters*timelen
|
| 33 |
+
music_fea = music_fea[:max_length, :]
|
| 34 |
+
motion_fea = motion_fea[:max_length, :]
|
| 35 |
+
|
| 36 |
+
for i in range(iters):
|
| 37 |
+
music_clip = music_fea[i*timelen: (i+1)*timelen, :]
|
| 38 |
+
motion_clip = motion_fea[i*timelen: (i+1)*timelen, :]
|
| 39 |
+
np.save(os.path.join(music_out, name + "z@" + str(i).zfill(3) + ".npy"), music_clip)
|
| 40 |
+
np.save(os.path.join(motion_out, name + "z@" + str(i).zfill(3) + ".npy"), motion_clip)
|
| 41 |
+
|
dataset/FineDance_dataset.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils import data
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
sys.path.insert(0,'.')
|
| 10 |
+
|
| 11 |
+
SMPL_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22]
|
| 12 |
+
|
| 13 |
+
SMPLX_JOINTS_FLIP_PERM = [0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13,
|
| 14 |
+
15, 17, 16, 19, 18, 21, 20, 22, 24, 23,
|
| 15 |
+
40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
|
| 16 |
+
25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
|
| 17 |
+
SMPLX_POSE_FLIP_PERM = []
|
| 18 |
+
for i in SMPLX_JOINTS_FLIP_PERM:
|
| 19 |
+
SMPLX_POSE_FLIP_PERM.append(3*i)
|
| 20 |
+
SMPLX_POSE_FLIP_PERM.append(3*i+1)
|
| 21 |
+
SMPLX_POSE_FLIP_PERM.append(3*i+2)
|
| 22 |
+
|
| 23 |
+
def flip_pose(pose):
|
| 24 |
+
#Flip pose.The flipping is based on SMPLX parameters.
|
| 25 |
+
pose = pose[:,SMPLX_POSE_FLIP_PERM]
|
| 26 |
+
# we also negate the second and the third dimension of the axis-angle
|
| 27 |
+
pose[:,1::3] = -pose[:,1::3]
|
| 28 |
+
pose[:,2::3] = -pose[:,2::3]
|
| 29 |
+
return pose
|
| 30 |
+
|
| 31 |
+
def get_train_test_list(datasplit):
|
| 32 |
+
all_list = []
|
| 33 |
+
train_list = []
|
| 34 |
+
for i in range(1,212):
|
| 35 |
+
all_list.append(str(i).zfill(3))
|
| 36 |
+
|
| 37 |
+
if datasplit == "cross_genre":
|
| 38 |
+
test_list = ["063", "132", "143", "036", "098", "198", "130", "012", "211", "193", "179", "065", "137", "161", "092", "120", "037", "109", "204", "144"]
|
| 39 |
+
ignor_list = ["116", "117", "118", "119", "120", "121", "122", "123", "202"]+["130"]
|
| 40 |
+
elif datasplit == "cross_dancer":
|
| 41 |
+
test_list = ['001','002','003','004','005','006','007','008','009','010','011','012','013','124','126','128','130','132']
|
| 42 |
+
ignor_list = ['115','117','119','121','122','135','137','139','141','143','145','147'] + ["116", "118", "120", "123", "202", "159"]+["130"] # 前一个列表为val set,后一个列表为ignore set
|
| 43 |
+
else:
|
| 44 |
+
raise("error of data split!")
|
| 45 |
+
for one in all_list:
|
| 46 |
+
if one not in test_list:
|
| 47 |
+
if one not in ignor_list:
|
| 48 |
+
train_list.append(one)
|
| 49 |
+
|
| 50 |
+
return train_list, test_list, ignor_list
|
| 51 |
+
|
| 52 |
+
class FineDance_Smpl(data.Dataset):
|
| 53 |
+
def __init__(self, args, istrain):
|
| 54 |
+
self.motion_dir = './data/finedance/motion_fea319'
|
| 55 |
+
self.music_dir = './data/finedance/music_npy'
|
| 56 |
+
self.istrain = istrain
|
| 57 |
+
self.seq_len = args.full_seq_len
|
| 58 |
+
slide = args.full_seq_len // args.windows
|
| 59 |
+
|
| 60 |
+
self.motion_index = []
|
| 61 |
+
self.music_index = []
|
| 62 |
+
self.name = []
|
| 63 |
+
motion_all = []
|
| 64 |
+
music_all = []
|
| 65 |
+
|
| 66 |
+
train_list, test_list, ignor_list = get_train_test_list(args.datasplit)
|
| 67 |
+
if self.istrain:
|
| 68 |
+
self.datalist= train_list
|
| 69 |
+
else:
|
| 70 |
+
self.datalist = test_list
|
| 71 |
+
|
| 72 |
+
total_length = 0 # 将数据集中的所有motion用同一个index索引
|
| 73 |
+
|
| 74 |
+
for name in tqdm(self.datalist):
|
| 75 |
+
save_name = name
|
| 76 |
+
name = name + ".npy"
|
| 77 |
+
|
| 78 |
+
if name[:-4] in ignor_list:
|
| 79 |
+
continue
|
| 80 |
+
|
| 81 |
+
motion = np.load(os.path.join(self.motion_dir, name))
|
| 82 |
+
music = np.load(os.path.join(self.music_dir, name))
|
| 83 |
+
|
| 84 |
+
min_all_len = min(motion.shape[0], music.shape[0])
|
| 85 |
+
motion = motion[:min_all_len]
|
| 86 |
+
if motion.shape[-1] == 168:
|
| 87 |
+
motion = np.concatenate([motion[:,:69], motion[:,78:]], axis=1) # 22, 25
|
| 88 |
+
elif motion.shape[-1] == 319:
|
| 89 |
+
pass
|
| 90 |
+
elif motion.shape[-1] == 315:
|
| 91 |
+
pass
|
| 92 |
+
# motion = np.concatenate([motion[:,:135], motion[:,153:]], axis=1) #
|
| 93 |
+
else:
|
| 94 |
+
raise("input motion shape error! not 168 or 319!")
|
| 95 |
+
music = music[:min_all_len] # motion = motion[:min_all_len]
|
| 96 |
+
nums = (min_all_len-self.seq_len) // slide + 1 # 舍弃了最后一段不满seq_len的motion
|
| 97 |
+
|
| 98 |
+
if self.istrain:
|
| 99 |
+
clip_index = []
|
| 100 |
+
for i in range(nums):
|
| 101 |
+
motion_clip = motion[i * slide: i * slide + self.seq_len]
|
| 102 |
+
if motion_clip.std(axis=0).mean() > 0.07: # 判断是否为有效motion,如果耗费时间,可以考虑删掉
|
| 103 |
+
clip_index.append(i)
|
| 104 |
+
index = np.array(clip_index) * slide + total_length # clip_index为local index
|
| 105 |
+
index_ = np.array(clip_index) * slide
|
| 106 |
+
else:
|
| 107 |
+
index = np.arange(nums) * slide + total_length
|
| 108 |
+
index_ = np.arange(nums) * slide
|
| 109 |
+
|
| 110 |
+
motion_all.append(motion)
|
| 111 |
+
music_all.append(music)
|
| 112 |
+
|
| 113 |
+
if args.mix:
|
| 114 |
+
motion_index = []
|
| 115 |
+
music_index = []
|
| 116 |
+
num = (len(index) - 1) // 8 + 1
|
| 117 |
+
for i in range(num):
|
| 118 |
+
motion_index_tmp, music_index_tmp = np.meshgrid(index[i*8:(i+1)*8], index[i*8:(i+1)*8]) # 这里i有问题?似乎没有
|
| 119 |
+
motion_index += motion_index_tmp.reshape((-1)).tolist()
|
| 120 |
+
music_index += music_index_tmp.reshape((-1)).tolist()
|
| 121 |
+
index_tmp = np.meshgrid(index_[i*8:(i+1)*8])
|
| 122 |
+
index_ += index_tmp.reshape((-1)).tolist()
|
| 123 |
+
else:
|
| 124 |
+
motion_index = index.tolist()
|
| 125 |
+
music_index = index.tolist()
|
| 126 |
+
index_ = index_.tolist()
|
| 127 |
+
index_ = [save_name + "_" + str(element).zfill(5) for element in index_]
|
| 128 |
+
|
| 129 |
+
self.motion_index += motion_index
|
| 130 |
+
self.music_index += music_index
|
| 131 |
+
total_length += min_all_len
|
| 132 |
+
self.name += index_
|
| 133 |
+
|
| 134 |
+
self.motion = np.concatenate(motion_all, axis=0).astype(np.float32)
|
| 135 |
+
self.music = np.concatenate(music_all, axis=0).astype(np.float32)
|
| 136 |
+
|
| 137 |
+
self.len = len(self.motion_index)
|
| 138 |
+
print(f'FineDance has {self.len} samples..')
|
| 139 |
+
|
| 140 |
+
def __len__(self):
|
| 141 |
+
return self.len
|
| 142 |
+
|
| 143 |
+
def __getitem__(self, index):
|
| 144 |
+
motion_index = self.motion_index[index]
|
| 145 |
+
music_index = self.music_index[index]
|
| 146 |
+
motion = self.motion[motion_index:motion_index+self.seq_len]
|
| 147 |
+
if motion.shape[-1] == 319 or motion.shape[-1] == 139:
|
| 148 |
+
motion[:, 4:7] = motion[:, 4:7] - motion[:1, 4:7] # The first 4 dimension are foot contact
|
| 149 |
+
else:
|
| 150 |
+
motion[:, :3] = motion[:, :3] - motion[:1, :3]
|
| 151 |
+
music = self.music[music_index:music_index+self.seq_len]
|
| 152 |
+
filename = self.name[index]
|
| 153 |
+
# if np.random.rand(1) > 0.5:
|
| 154 |
+
# motion = motion[:,self.mirror_idx]
|
| 155 |
+
return motion, music, filename
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == '__main__':
|
| 159 |
+
data_split = {}
|
| 160 |
+
all_list = []
|
| 161 |
+
train_list = []
|
| 162 |
+
for i in range(1,212):
|
| 163 |
+
all_list.append(str(i).zfill(3))
|
| 164 |
+
test_list = ["001","002","003","004","005","006","007","008","009","010","011","012","013","124","126","128","130","132"]
|
| 165 |
+
val_list = ["115","117","119","121","122","135","137","139","141","143","145","147"]
|
| 166 |
+
for one in all_list:
|
| 167 |
+
if one not in test_list:
|
| 168 |
+
if one not in val_list:
|
| 169 |
+
train_list.append(one)
|
| 170 |
+
|
| 171 |
+
data_split["train"] = train_list
|
| 172 |
+
data_split["test"] = test_list
|
| 173 |
+
data_split["val"] = val_list
|
| 174 |
+
data_split["ignore"] = ["116", "117", "118", "119", "120", "121", "122", "123", "202"]
|
| 175 |
+
|
| 176 |
+
with open("data_crossdancer.json", "w") as f:
|
| 177 |
+
json.dump(data_split,f)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
print(train_list)
|
dataset/__init__.py
ADDED
|
File without changes
|
dataset/preprocess.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from .scaler import MinMaxScaler
|
| 9 |
+
import pickle
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
| 13 |
+
# Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
| 14 |
+
path = Path(path) # os-agnostic
|
| 15 |
+
if path.exists() and not exist_ok:
|
| 16 |
+
suffix = path.suffix
|
| 17 |
+
path = path.with_suffix("")
|
| 18 |
+
dirs = glob.glob(f"{path}{sep}*") # similar paths
|
| 19 |
+
matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
|
| 20 |
+
i = [int(m.groups()[0]) for m in matches if m] # indices
|
| 21 |
+
n = max(i) + 1 if i else 2 # increment number
|
| 22 |
+
path = Path(f"{path}{sep}{n}{suffix}") # update path
|
| 23 |
+
dir = path if path.suffix == "" else path.parent # directory
|
| 24 |
+
if not dir.exists() and mkdir:
|
| 25 |
+
dir.mkdir(parents=True, exist_ok=True) # make directory
|
| 26 |
+
return path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Normalizer:
|
| 30 |
+
def __init__(self, data):
|
| 31 |
+
flat = data.reshape(-1, data.shape[-1]) # bxt , 151
|
| 32 |
+
self.scaler = MinMaxScaler((-1, 1), clip=True)
|
| 33 |
+
self.scaler.fit(flat)
|
| 34 |
+
|
| 35 |
+
def normalize(self, x):
|
| 36 |
+
batch, seq, ch = x.shape
|
| 37 |
+
x = x.reshape(-1, ch)
|
| 38 |
+
return self.scaler.transform(x).reshape((batch, seq, ch))
|
| 39 |
+
|
| 40 |
+
def unnormalize(self, x):
|
| 41 |
+
batch, seq, ch = x.shape
|
| 42 |
+
x = x.reshape(-1, ch)
|
| 43 |
+
x = torch.clip(x, -1, 1) # clip to force compatibility
|
| 44 |
+
return self.scaler.inverse_transform(x).reshape((batch, seq, ch))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class My_Normalizer:
|
| 48 |
+
def __init__(self, data):
|
| 49 |
+
if isinstance(data, str):
|
| 50 |
+
self.scaler = MinMaxScaler((-1, 1), clip=True)
|
| 51 |
+
with open(data, 'rb') as f:
|
| 52 |
+
normalizer_state_dict = pickle.load(f)
|
| 53 |
+
# normalizer_state_dict = torch.load(data)
|
| 54 |
+
self.scaler.scale_ = normalizer_state_dict["scale"]
|
| 55 |
+
self.scaler.min_ = normalizer_state_dict["min"]
|
| 56 |
+
else:
|
| 57 |
+
flat = data.reshape(-1, data.shape[-1]) # bxt , 151
|
| 58 |
+
self.scaler = MinMaxScaler((-1, 1), clip=True)
|
| 59 |
+
self.scaler.fit(flat)
|
| 60 |
+
|
| 61 |
+
def normalize(self, x):
|
| 62 |
+
if len(x.shape) == 3:
|
| 63 |
+
batch, seq, ch = x.shape
|
| 64 |
+
x = x.reshape(-1, ch)
|
| 65 |
+
return self.scaler.transform(x).reshape((batch, seq, ch))
|
| 66 |
+
elif len(x.shape) == 2:
|
| 67 |
+
batch, ch = x.shape
|
| 68 |
+
return self.scaler.transform(x)
|
| 69 |
+
else:
|
| 70 |
+
raise("input error!")
|
| 71 |
+
|
| 72 |
+
def unnormalize(self, x):
|
| 73 |
+
if len(x.shape) == 3:
|
| 74 |
+
batch, seq, ch = x.shape
|
| 75 |
+
x = x.reshape(-1, ch)
|
| 76 |
+
x = torch.clip(x, -1, 1) # clip to force compatibility
|
| 77 |
+
return self.scaler.inverse_transform(x).reshape((batch, seq, ch))
|
| 78 |
+
elif len(x.shape) == 2:
|
| 79 |
+
x = torch.clip(x, -1, 1)
|
| 80 |
+
return self.scaler.inverse_transform(x)
|
| 81 |
+
else:
|
| 82 |
+
raise("input error!")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def vectorize_many(data):
|
| 86 |
+
# given a list of batch x seqlen x joints? x channels, flatten all to batch x seqlen x -1, concatenate
|
| 87 |
+
batch_size = data[0].shape[0]
|
| 88 |
+
seq_len = data[0].shape[1]
|
| 89 |
+
|
| 90 |
+
out = [x.reshape(batch_size, seq_len, -1).contiguous() for x in data]
|
| 91 |
+
|
| 92 |
+
global_pose_vec_gt = torch.cat(out, dim=2)
|
| 93 |
+
return global_pose_vec_gt
|
dataset/quaternion.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from pytorch3d.transforms import (axis_angle_to_matrix, matrix_to_axis_angle,
|
| 3 |
+
matrix_to_quaternion, matrix_to_rotation_6d,
|
| 4 |
+
quaternion_to_matrix, rotation_6d_to_matrix)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def quat_to_6v(q):
|
| 8 |
+
assert q.shape[-1] == 4
|
| 9 |
+
mat = quaternion_to_matrix(q)
|
| 10 |
+
mat = matrix_to_rotation_6d(mat)
|
| 11 |
+
return mat
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def quat_from_6v(q):
|
| 15 |
+
assert q.shape[-1] == 6
|
| 16 |
+
mat = rotation_6d_to_matrix(q)
|
| 17 |
+
quat = matrix_to_quaternion(mat)
|
| 18 |
+
return quat
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def ax_to_6v(q):
|
| 22 |
+
assert q.shape[-1] == 3
|
| 23 |
+
mat = axis_angle_to_matrix(q)
|
| 24 |
+
mat = matrix_to_rotation_6d(mat)
|
| 25 |
+
return mat
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def ax_from_6v(q):
|
| 29 |
+
assert q.shape[-1] == 6
|
| 30 |
+
mat = rotation_6d_to_matrix(q)
|
| 31 |
+
ax = matrix_to_axis_angle(mat)
|
| 32 |
+
return ax
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def quat_slerp(x, y, a):
|
| 36 |
+
"""
|
| 37 |
+
Performs spherical linear interpolation (SLERP) between x and y, with proportion a
|
| 38 |
+
|
| 39 |
+
:param x: quaternion tensor (N, S, J, 4)
|
| 40 |
+
:param y: quaternion tensor (N, S, J, 4)
|
| 41 |
+
:param a: interpolation weight (S, )
|
| 42 |
+
:return: tensor of interpolation results
|
| 43 |
+
"""
|
| 44 |
+
len = torch.sum(x * y, axis=-1)
|
| 45 |
+
|
| 46 |
+
neg = len < 0.0
|
| 47 |
+
len[neg] = -len[neg]
|
| 48 |
+
y[neg] = -y[neg]
|
| 49 |
+
|
| 50 |
+
a = torch.zeros_like(x[..., 0]) + a
|
| 51 |
+
|
| 52 |
+
amount0 = torch.zeros_like(a)
|
| 53 |
+
amount1 = torch.zeros_like(a)
|
| 54 |
+
|
| 55 |
+
linear = (1.0 - len) < 0.01
|
| 56 |
+
omegas = torch.arccos(len[~linear])
|
| 57 |
+
sinoms = torch.sin(omegas)
|
| 58 |
+
|
| 59 |
+
amount0[linear] = 1.0 - a[linear]
|
| 60 |
+
amount0[~linear] = torch.sin((1.0 - a[~linear]) * omegas) / sinoms
|
| 61 |
+
|
| 62 |
+
amount1[linear] = a[linear]
|
| 63 |
+
amount1[~linear] = torch.sin(a[~linear] * omegas) / sinoms
|
| 64 |
+
|
| 65 |
+
# reshape
|
| 66 |
+
amount0 = amount0[..., None]
|
| 67 |
+
amount1 = amount1[..., None]
|
| 68 |
+
|
| 69 |
+
res = amount0 * x + amount1 * y
|
| 70 |
+
|
| 71 |
+
return res
|
dataset/scaler.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
|
| 5 |
+
# if we are fitting on 1D arrays, scale might be a scalar
|
| 6 |
+
if constant_mask is None:
|
| 7 |
+
# Detect near constant values to avoid dividing by a very small
|
| 8 |
+
# value that could lead to surprising results and numerical
|
| 9 |
+
# stability issues.
|
| 10 |
+
constant_mask = scale < 10 * torch.finfo(scale.dtype).eps
|
| 11 |
+
|
| 12 |
+
if copy:
|
| 13 |
+
# New array to avoid side-effects
|
| 14 |
+
scale = scale.clone()
|
| 15 |
+
scale[constant_mask] = 1.0
|
| 16 |
+
return scale
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MinMaxScaler:
|
| 20 |
+
_parameter_constraints: dict = {
|
| 21 |
+
"feature_range": [tuple],
|
| 22 |
+
"copy": ["boolean"],
|
| 23 |
+
"clip": ["boolean"],
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
def __init__(self, feature_range=(0, 1), *, copy=True, clip=False):
|
| 27 |
+
self.feature_range = feature_range
|
| 28 |
+
self.copy = copy
|
| 29 |
+
self.clip = clip
|
| 30 |
+
|
| 31 |
+
def _reset(self):
|
| 32 |
+
"""Reset internal data-dependent state of the scaler, if necessary.
|
| 33 |
+
__init__ parameters are not touched.
|
| 34 |
+
"""
|
| 35 |
+
# Checking one attribute is enough, because they are all set together
|
| 36 |
+
# in partial_fit
|
| 37 |
+
if hasattr(self, "scale_"):
|
| 38 |
+
del self.scale_
|
| 39 |
+
del self.min_
|
| 40 |
+
del self.n_samples_seen_
|
| 41 |
+
del self.data_min_
|
| 42 |
+
del self.data_max_
|
| 43 |
+
del self.data_range_
|
| 44 |
+
|
| 45 |
+
def fit(self, X):
|
| 46 |
+
# Reset internal state before fitting
|
| 47 |
+
self._reset()
|
| 48 |
+
return self.partial_fit(X)
|
| 49 |
+
|
| 50 |
+
def partial_fit(self, X):
|
| 51 |
+
feature_range = self.feature_range
|
| 52 |
+
if feature_range[0] >= feature_range[1]:
|
| 53 |
+
raise ValueError(
|
| 54 |
+
"Minimum of desired feature range must be smaller than maximum. Got %s."
|
| 55 |
+
% str(feature_range)
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
data_min = torch.min(X, axis=0)[0]
|
| 59 |
+
data_max = torch.max(X, axis=0)[0]
|
| 60 |
+
|
| 61 |
+
self.n_samples_seen_ = X.shape[0]
|
| 62 |
+
|
| 63 |
+
data_range = data_max - data_min
|
| 64 |
+
self.scale_ = (feature_range[1] - feature_range[0]) / _handle_zeros_in_scale(
|
| 65 |
+
data_range, copy=True
|
| 66 |
+
)
|
| 67 |
+
self.min_ = feature_range[0] - data_min * self.scale_
|
| 68 |
+
self.data_min_ = data_min
|
| 69 |
+
self.data_max_ = data_max
|
| 70 |
+
self.data_range_ = data_range
|
| 71 |
+
return self
|
| 72 |
+
|
| 73 |
+
def transform(self, X):
|
| 74 |
+
X *= self.scale_.to(X.device)
|
| 75 |
+
X += self.min_.to(X.device)
|
| 76 |
+
if self.clip:
|
| 77 |
+
torch.clip(X, self.feature_range[0], self.feature_range[1], out=X)
|
| 78 |
+
return X
|
| 79 |
+
|
| 80 |
+
def inverse_transform(self, X):
|
| 81 |
+
X -= self.min_[-X.shape[1] :].to(X.device)
|
| 82 |
+
X /= self.scale_[-X.shape[1] :].to(X.device)
|
| 83 |
+
return X
|
environment.yaml
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: FineNet
|
| 2 |
+
channels:
|
| 3 |
+
- anaconda
|
| 4 |
+
- pytorch
|
| 5 |
+
- conda-forge
|
| 6 |
+
- https://repo.anaconda.com/pkgs/main
|
| 7 |
+
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/peterjc123/
|
| 8 |
+
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
|
| 9 |
+
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
|
| 10 |
+
- defaults
|
| 11 |
+
dependencies:
|
| 12 |
+
- _libgcc_mutex=0.1=main
|
| 13 |
+
- _openmp_mutex=5.1=1_gnu
|
| 14 |
+
- aiohttp=3.8.1=py38h0a891b7_1
|
| 15 |
+
- aiosignal=1.3.1=pyhd8ed1ab_0
|
| 16 |
+
- asttokens=2.2.1=pyhd8ed1ab_0
|
| 17 |
+
- async-timeout=4.0.3=pyhd8ed1ab_0
|
| 18 |
+
- attrs=23.1.0=pyh71513ae_1
|
| 19 |
+
- backcall=0.2.0=pyh9f0ad1d_0
|
| 20 |
+
- backports=1.0=pyhd8ed1ab_3
|
| 21 |
+
- backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
|
| 22 |
+
- blas=1.0=mkl
|
| 23 |
+
- blinker=1.7.0=pyhd8ed1ab_0
|
| 24 |
+
- brotlipy=0.7.0=py38h0a891b7_1004
|
| 25 |
+
- bzip2=1.0.8=h7f98852_4
|
| 26 |
+
- c-ares=1.18.1=h7f98852_0
|
| 27 |
+
- ca-certificates=2023.11.17=hbcca054_0
|
| 28 |
+
- certifi=2023.11.17=pyhd8ed1ab_0
|
| 29 |
+
- cffi=1.15.0=py38h3931269_0
|
| 30 |
+
- charset-normalizer=2.1.1=pyhd8ed1ab_0
|
| 31 |
+
- colorama=0.4.6=pyhd8ed1ab_0
|
| 32 |
+
- cpuonly=1.0=0
|
| 33 |
+
- cryptography=37.0.2=py38h2b5fc30_0
|
| 34 |
+
- cudatoolkit=11.3.1=h9edb442_10
|
| 35 |
+
- cycler=0.11.0=pyhd8ed1ab_0
|
| 36 |
+
- dataclasses=0.8=pyhc8e2a94_3
|
| 37 |
+
- debugpy=1.5.1=py38h295c915_0
|
| 38 |
+
- entrypoints=0.4=pyhd8ed1ab_0
|
| 39 |
+
- executing=1.2.0=pyhd8ed1ab_0
|
| 40 |
+
- ffmpeg=4.3=hf484d3e_0
|
| 41 |
+
- freetype=2.10.4=h0708190_1
|
| 42 |
+
- frozenlist=1.3.0=py38h0a891b7_1
|
| 43 |
+
- fsspec=2023.5.0=pyh1a96a4e_0
|
| 44 |
+
- future=0.18.3=pyhd8ed1ab_0
|
| 45 |
+
- geos=3.10.2=h9c3ff4c_0
|
| 46 |
+
- giflib=5.2.1=h5eee18b_1
|
| 47 |
+
- gmp=6.2.1=h58526e2_0
|
| 48 |
+
- gnutls=3.6.13=h85f3911_1
|
| 49 |
+
- icu=67.1=he1b5a44_0
|
| 50 |
+
- idna=3.4=pyhd8ed1ab_0
|
| 51 |
+
- importlib-metadata=6.8.0=pyha770c72_0
|
| 52 |
+
- intel-openmp=2021.4.0=h06a4308_3561
|
| 53 |
+
- ipykernel=6.14.0=py38h7f3c49e_0
|
| 54 |
+
- ipython=8.4.0=py38h578d9bd_0
|
| 55 |
+
- jedi=0.18.2=pyhd8ed1ab_0
|
| 56 |
+
- jpeg=9e=h166bdaf_1
|
| 57 |
+
- jupyter_client=7.0.6=pyhd8ed1ab_0
|
| 58 |
+
- jupyter_core=4.12.0=py38h578d9bd_0
|
| 59 |
+
- kiwisolver=1.4.4=py38h6a678d5_0
|
| 60 |
+
- lame=3.100=h7f98852_1001
|
| 61 |
+
- lcms2=2.12=hddcbb42_0
|
| 62 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
| 63 |
+
- libffi=3.4.2=h6a678d5_6
|
| 64 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 65 |
+
- libgfortran-ng=7.5.0=ha8ba4b0_17
|
| 66 |
+
- libgfortran4=7.5.0=ha8ba4b0_17
|
| 67 |
+
- libgomp=11.2.0=h1234567_1
|
| 68 |
+
- libiconv=1.17=h166bdaf_0
|
| 69 |
+
- libpng=1.6.37=h21135ba_2
|
| 70 |
+
- libprotobuf=3.18.0=h780b84a_1
|
| 71 |
+
- libsodium=1.0.18=h36c2ea0_1
|
| 72 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 73 |
+
- libtiff=4.2.0=hf544144_3
|
| 74 |
+
- libuv=1.43.0=h7f98852_0
|
| 75 |
+
- libwebp=1.2.2=h55f646e_0
|
| 76 |
+
- libwebp-base=1.2.2=h7f98852_1
|
| 77 |
+
- lightning-utilities=0.8.0=pyhd8ed1ab_0
|
| 78 |
+
- lz4-c=1.9.3=h9c3ff4c_1
|
| 79 |
+
- mapbox_earcut=1.0.0=py38h43d8883_3
|
| 80 |
+
- matplotlib-base=3.2.2=py38h5d868c9_1
|
| 81 |
+
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
|
| 82 |
+
- mkl=2021.4.0=h06a4308_640
|
| 83 |
+
- mkl-service=2.4.0=py38h95df7f1_0
|
| 84 |
+
- mkl_fft=1.3.1=py38h8666266_1
|
| 85 |
+
- mkl_random=1.2.2=py38h1abd341_0
|
| 86 |
+
- mpi=1.0=mpich
|
| 87 |
+
- mpi4py=3.1.4=py38hfc96bbd_0
|
| 88 |
+
- mpich=3.3.2=hc856adb_0
|
| 89 |
+
- multidict=6.0.2=py38h0a891b7_1
|
| 90 |
+
- ncurses=6.3=h5eee18b_3
|
| 91 |
+
- nest-asyncio=1.5.6=pyhd8ed1ab_0
|
| 92 |
+
- nettle=3.6=he412f7d_0
|
| 93 |
+
- networkx=3.1=pyhd8ed1ab_0
|
| 94 |
+
- ninja=1.11.0=h924138e_0
|
| 95 |
+
- oauthlib=3.2.2=pyhd8ed1ab_0
|
| 96 |
+
- olefile=0.46=pyh9f0ad1d_1
|
| 97 |
+
- openh264=2.1.1=h780b84a_0
|
| 98 |
+
- openjpeg=2.4.0=hb52868f_1
|
| 99 |
+
- openssl=1.1.1w=h7f8727e_0
|
| 100 |
+
- parso=0.8.3=pyhd8ed1ab_0
|
| 101 |
+
- pexpect=4.8.0=pyh1a96a4e_2
|
| 102 |
+
- pickleshare=0.7.5=py_1003
|
| 103 |
+
- prompt-toolkit=3.0.39=pyha770c72_0
|
| 104 |
+
- ptyprocess=0.7.0=pyhd3deb0d_0
|
| 105 |
+
- pure_eval=0.2.2=pyhd8ed1ab_0
|
| 106 |
+
- pyasn1=0.5.0=pyhd8ed1ab_0
|
| 107 |
+
- pyasn1-modules=0.3.0=pyhd8ed1ab_0
|
| 108 |
+
- pycparser=2.21=pyhd8ed1ab_0
|
| 109 |
+
- pyjwt=2.8.0=pyhd8ed1ab_0
|
| 110 |
+
- pyopenssl=22.0.0=pyhd8ed1ab_1
|
| 111 |
+
- pyrender=0.1.45=pyh8a188c0_3
|
| 112 |
+
- pysocks=1.7.1=pyha2e5f31_6
|
| 113 |
+
- python=3.8.15=h7a1cb2a_2
|
| 114 |
+
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
| 115 |
+
- python_abi=3.8=2_cp38
|
| 116 |
+
- pytorch-lightning=1.5.8=pyhd8ed1ab_0
|
| 117 |
+
- pytorch-mutex=1.0=cuda
|
| 118 |
+
- pyu2f=0.1.5=pyhd8ed1ab_0
|
| 119 |
+
- pyyaml=6.0=py38h0a891b7_4
|
| 120 |
+
- pyzmq=19.0.2=py38ha71036d_2
|
| 121 |
+
- readline=8.2=h5eee18b_0
|
| 122 |
+
- requests=2.28.2=pyhd8ed1ab_0
|
| 123 |
+
- requests-oauthlib=1.3.1=pyhd8ed1ab_0
|
| 124 |
+
- rsa=4.9=pyhd8ed1ab_0
|
| 125 |
+
- six=1.16.0=pyh6c4a22f_0
|
| 126 |
+
- sqlite=3.40.0=h5082296_0
|
| 127 |
+
- stack_data=0.6.2=pyhd8ed1ab_0
|
| 128 |
+
- tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
|
| 129 |
+
- tk=8.6.12=h1ccaba5_0
|
| 130 |
+
- tornado=6.1=py38h0a891b7_3
|
| 131 |
+
- tqdm=4.65.0=pyhd8ed1ab_1
|
| 132 |
+
- traitlets=5.9.0=pyhd8ed1ab_0
|
| 133 |
+
- typing-extensions=4.4.0=hd8ed1ab_0
|
| 134 |
+
- typing_extensions=4.4.0=pyha770c72_0
|
| 135 |
+
- urllib3=1.26.14=pyhd8ed1ab_0
|
| 136 |
+
- wcwidth=0.2.6=pyhd8ed1ab_0
|
| 137 |
+
- xz=5.2.8=h5eee18b_0
|
| 138 |
+
- yaml=0.2.5=h7f98852_2
|
| 139 |
+
- yarl=1.7.2=py38h0a891b7_2
|
| 140 |
+
- zeromq=4.3.4=h9c3ff4c_1
|
| 141 |
+
- zlib=1.2.13=h5eee18b_0
|
| 142 |
+
- zstd=1.5.0=ha95c52a_0
|
| 143 |
+
- pip:
|
| 144 |
+
- absl-py==1.4.0
|
| 145 |
+
- accelerate==0.19.0
|
| 146 |
+
- alembic==1.12.0
|
| 147 |
+
- aniposelib==0.4.3
|
| 148 |
+
- antlr4-python3-runtime==4.8
|
| 149 |
+
- appdirs==1.4.4
|
| 150 |
+
- audioread==3.0.0
|
| 151 |
+
- autopage==0.5.1
|
| 152 |
+
- backoff==2.2.1
|
| 153 |
+
- beautifulsoup4==4.12.2
|
| 154 |
+
- bertopic==0.15.0
|
| 155 |
+
- blobfile==2.0.2
|
| 156 |
+
- boilerpy3==1.0.6
|
| 157 |
+
- cachetools==5.3.1
|
| 158 |
+
- canals==0.2.2
|
| 159 |
+
- cattrs==23.1.2
|
| 160 |
+
- chumpy==0.69
|
| 161 |
+
- click==8.1.3
|
| 162 |
+
- cliff==4.3.0
|
| 163 |
+
- clip==1.0
|
| 164 |
+
- cmaes==0.10.0
|
| 165 |
+
- cmd2==2.4.3
|
| 166 |
+
- colorlog==6.7.0
|
| 167 |
+
- commonmark==0.9.1
|
| 168 |
+
- configer==1.3.1
|
| 169 |
+
- configparser==5.3.0
|
| 170 |
+
- contourpy==1.0.7
|
| 171 |
+
- coremltools==6.1
|
| 172 |
+
- cython==0.29.35
|
| 173 |
+
- decorator==4.4.2
|
| 174 |
+
- diffusers==0.16.1
|
| 175 |
+
- dill==0.3.6
|
| 176 |
+
- docker-pycreds==0.4.0
|
| 177 |
+
- docopt==0.6.2
|
| 178 |
+
- easydict==1.7
|
| 179 |
+
- einops==0.6.1
|
| 180 |
+
- etils==0.9.0
|
| 181 |
+
- events==0.4
|
| 182 |
+
- exceptiongroup==1.1.2
|
| 183 |
+
- farm-haystack==1.18.1
|
| 184 |
+
- filelock==3.12.0
|
| 185 |
+
- fire==0.1.3
|
| 186 |
+
- fonttools==4.39.4
|
| 187 |
+
- freetype-py==2.3.0
|
| 188 |
+
- ftfy==6.1.1
|
| 189 |
+
- fvcore==0.1.5.post20221221
|
| 190 |
+
- gdown==4.7.1
|
| 191 |
+
- gitdb==4.0.10
|
| 192 |
+
- gitpython==3.1.31
|
| 193 |
+
- google-auth==2.22.0
|
| 194 |
+
- google-auth-oauthlib==1.0.0
|
| 195 |
+
- googleapis-common-protos==1.57.0
|
| 196 |
+
- greenlet==2.0.2
|
| 197 |
+
- grpcio==1.56.0
|
| 198 |
+
- h5py==3.9.0
|
| 199 |
+
- hdbscan==0.8.33
|
| 200 |
+
- huggingface-hub==0.14.1
|
| 201 |
+
- hydra==2.5
|
| 202 |
+
- hydra-colorlog==1.1.0.dev1
|
| 203 |
+
- hydra-core==1.1.0rc1
|
| 204 |
+
- hydra-optuna-sweeper==1.1.0.dev2
|
| 205 |
+
- imageio==2.27.0
|
| 206 |
+
- imageio-ffmpeg==0.4.9
|
| 207 |
+
- importlib-resources==5.10.1
|
| 208 |
+
- inflect==7.0.0
|
| 209 |
+
- iopath==0.1.10
|
| 210 |
+
- joblib==1.2.0
|
| 211 |
+
- json-tricks==3.17.1
|
| 212 |
+
- jsonschema==4.18.4
|
| 213 |
+
- jsonschema-specifications==2023.7.1
|
| 214 |
+
- jukebox==1.0
|
| 215 |
+
- lazy-imports==0.3.1
|
| 216 |
+
- lazy-loader==0.2
|
| 217 |
+
- librosa==0.7.2
|
| 218 |
+
- llvmlite==0.31.0
|
| 219 |
+
- lxml==4.9.2
|
| 220 |
+
- mako==1.2.4
|
| 221 |
+
- markdown==3.4.3
|
| 222 |
+
- markupsafe==2.1.3
|
| 223 |
+
- matplotlib==3.7.3
|
| 224 |
+
- monotonic==1.6
|
| 225 |
+
- more-itertools==10.0.0
|
| 226 |
+
- moviepy==1.0.3
|
| 227 |
+
- mpmath==1.2.1
|
| 228 |
+
- msgpack==1.0.5
|
| 229 |
+
- multiprocess==0.70.14
|
| 230 |
+
- netifaces==0.11.0
|
| 231 |
+
- nltk==3.8.1
|
| 232 |
+
- num2words==0.5.12
|
| 233 |
+
- numba==0.48.0
|
| 234 |
+
- numpy==1.24.4
|
| 235 |
+
- omegaconf==2.1.0rc1
|
| 236 |
+
- onnx==1.12.0
|
| 237 |
+
- onnxoptimizer==0.3.2
|
| 238 |
+
- onnxsim==0.4.10
|
| 239 |
+
- opencv-contrib-python==4.8.0.74
|
| 240 |
+
- opencv-python==4.7.0.72
|
| 241 |
+
- optuna==2.4.0
|
| 242 |
+
- p-tqdm==1.4.0
|
| 243 |
+
- packaging==22.0
|
| 244 |
+
- pandas==1.2.4
|
| 245 |
+
- pathos==0.3.0
|
| 246 |
+
- pathtools==0.1.2
|
| 247 |
+
- pbr==5.11.1
|
| 248 |
+
- pickle5==0.0.11
|
| 249 |
+
- pillow==9.5.0
|
| 250 |
+
- pip==23.3.1
|
| 251 |
+
- pkgutil-resolve-name==1.3.10
|
| 252 |
+
- platformdirs==3.9.1
|
| 253 |
+
- plotly==5.17.0
|
| 254 |
+
- pooch==1.6.0
|
| 255 |
+
- portalocker==2.7.0
|
| 256 |
+
- posthog==3.0.1
|
| 257 |
+
- pox==0.3.2
|
| 258 |
+
- ppft==1.7.6.6
|
| 259 |
+
- prettytable==3.9.0
|
| 260 |
+
- proglog==0.1.10
|
| 261 |
+
- promise==2.3
|
| 262 |
+
- prompthub-py==4.0.0
|
| 263 |
+
- protobuf==3.20.1
|
| 264 |
+
- psutil==5.9.5
|
| 265 |
+
- publicip==1.0.1
|
| 266 |
+
- pycocotools==2.0.6
|
| 267 |
+
- pycryptodomex==3.17
|
| 268 |
+
- pydantic==1.10.11
|
| 269 |
+
- pydeprecate==0.3.2
|
| 270 |
+
- pydub==0.25.1
|
| 271 |
+
- pyglet==1.4.0b1
|
| 272 |
+
- pygments==2.13.0
|
| 273 |
+
- pynndescent==0.5.10
|
| 274 |
+
- pyopengl==3.1.0
|
| 275 |
+
- pyopengl-accelerate==3.1.7
|
| 276 |
+
- pyparsing==3.0.9
|
| 277 |
+
- pyperclip==1.8.2
|
| 278 |
+
- python-dotenv==0.17.1
|
| 279 |
+
- pytorch3d==0.3.0
|
| 280 |
+
- pytz==2022.6
|
| 281 |
+
- pywavelets==1.4.1
|
| 282 |
+
- quantulum3==0.9.0
|
| 283 |
+
- rank-bm25==0.2.2
|
| 284 |
+
- referencing==0.30.0
|
| 285 |
+
- regex==2023.5.5
|
| 286 |
+
- requests-cache==0.9.8
|
| 287 |
+
- resampy==0.3.1
|
| 288 |
+
- rich==12.6.0
|
| 289 |
+
- rpds-py==0.9.2
|
| 290 |
+
- safetensors==0.3.1
|
| 291 |
+
- scikit-image==0.18.0
|
| 292 |
+
- scikit-learn==1.2.2
|
| 293 |
+
- scipy==1.10.1
|
| 294 |
+
- sentence-transformers==2.2.2
|
| 295 |
+
- sentencepiece==0.1.99
|
| 296 |
+
- sentry-sdk==1.25.0
|
| 297 |
+
- setproctitle==1.3.2
|
| 298 |
+
- setuptools==68.2.2
|
| 299 |
+
- shapely==2.0.1
|
| 300 |
+
- smmap==5.0.0
|
| 301 |
+
- smplx==0.1.28
|
| 302 |
+
- soundfile==0.10.3.post1
|
| 303 |
+
- soupsieve==2.4.1
|
| 304 |
+
- soxr==0.3.5
|
| 305 |
+
- sqlalchemy==2.0.20
|
| 306 |
+
- sseclient-py==1.7.2
|
| 307 |
+
- stevedore==5.1.0
|
| 308 |
+
- sympy==1.11.1
|
| 309 |
+
- tabulate==0.9.0
|
| 310 |
+
- tbb==2021.10.0
|
| 311 |
+
- tenacity==8.2.2
|
| 312 |
+
- tensorboard==2.13.0
|
| 313 |
+
- tensorboard-data-server==0.7.1
|
| 314 |
+
- tensorboardx==1.6
|
| 315 |
+
- tensorflow-datasets==4.7.0
|
| 316 |
+
- tensorflow-metadata==1.12.0
|
| 317 |
+
- tf2onnx==1.13.0
|
| 318 |
+
- threadpoolctl==3.1.0
|
| 319 |
+
- tifffile==2023.7.4
|
| 320 |
+
- tiktoken==0.4.0
|
| 321 |
+
- timm==0.4.5
|
| 322 |
+
- tokenizers==0.13.3
|
| 323 |
+
- toml==0.10.2
|
| 324 |
+
- torch==1.13.1+cu116
|
| 325 |
+
- torch-tb-profiler==0.4.1
|
| 326 |
+
- torchaudio==0.13.1+cu116
|
| 327 |
+
- torchgeometry==0.1.2
|
| 328 |
+
- torchmetrics==0.7.0
|
| 329 |
+
- torchvision==0.14.1+cu116
|
| 330 |
+
- transformers==4.30.1
|
| 331 |
+
- transforms3d==0.4.1
|
| 332 |
+
- trimesh==3.9.24
|
| 333 |
+
- tzdata==2023.3
|
| 334 |
+
- umap-learn==0.5.4
|
| 335 |
+
- unidecode==1.1.1
|
| 336 |
+
- url-normalize==1.4.3
|
| 337 |
+
- wandb==0.15.2
|
| 338 |
+
- werkzeug==2.3.6
|
| 339 |
+
- wget==3.2
|
| 340 |
+
- wheel==0.41.2
|
| 341 |
+
- yacs==0.1.8
|
| 342 |
+
- zipp==3.16.2
|
| 343 |
+
prefix: /home/lrh/.conda/envs/py38
|
environment_macos.yaml
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: FineNet
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- conda-forge
|
| 5 |
+
- defaults
|
| 6 |
+
dependencies:
|
| 7 |
+
- python=3.9
|
| 8 |
+
- numpy
|
| 9 |
+
- scipy
|
| 10 |
+
- matplotlib
|
| 11 |
+
- pandas
|
| 12 |
+
- pyyaml
|
| 13 |
+
- h5py
|
| 14 |
+
- tqdm
|
| 15 |
+
- ipython
|
| 16 |
+
- jupyter_client
|
| 17 |
+
- cython
|
| 18 |
+
- ffmpeg
|
| 19 |
+
- pip
|
| 20 |
+
- pip:
|
| 21 |
+
- torch
|
| 22 |
+
- torchaudio
|
| 23 |
+
- torchvision
|
| 24 |
+
- pytorch-lightning==1.9.5
|
| 25 |
+
- torchmetrics==0.11.4
|
| 26 |
+
- accelerate
|
| 27 |
+
- einops
|
| 28 |
+
- smplx
|
| 29 |
+
- trimesh
|
| 30 |
+
- pyrender
|
| 31 |
+
- opencv-python
|
| 32 |
+
- opencv-contrib-python
|
| 33 |
+
- scikit-learn
|
| 34 |
+
- scikit-image
|
| 35 |
+
- transformers
|
| 36 |
+
- diffusers
|
| 37 |
+
- librosa
|
| 38 |
+
- soundfile
|
| 39 |
+
- moviepy
|
| 40 |
+
- imageio
|
| 41 |
+
- imageio-ffmpeg
|
| 42 |
+
- hydra-core==1.3.2
|
| 43 |
+
- omegaconf==2.3.0
|
| 44 |
+
- wandb
|
| 45 |
+
- tensorboard
|
| 46 |
+
- tensorboardx
|
| 47 |
+
- easydict
|
| 48 |
+
- fire
|
| 49 |
+
- ftfy
|
| 50 |
+
- regex
|
| 51 |
+
- pillow
|
| 52 |
+
- plotly
|
| 53 |
+
- gdown
|
| 54 |
+
- huggingface-hub
|
| 55 |
+
- safetensors
|
| 56 |
+
- sentence-transformers
|
| 57 |
+
- pydub
|
| 58 |
+
- json-tricks
|
| 59 |
+
- yacs
|
| 60 |
+
- fvcore
|
| 61 |
+
- iopath
|
| 62 |
+
- tabulate
|
| 63 |
+
- rich
|
| 64 |
+
- click
|
generate_all.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os,sys
|
| 3 |
+
from functools import cmp_to_key
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# import jukemirlib
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from args import FineDance_parse_test_opt
|
| 12 |
+
from train_seq import EDGE
|
| 13 |
+
from dataset.FineDance_dataset import get_train_test_list
|
| 14 |
+
|
| 15 |
+
# test_list = ["063", "132", "143", "036", "098", "198", "130", "012", "211", "193", "179", "065", "137", "161", "092", "120", "037", "109", "204", "144"]
|
| 16 |
+
test_list = ["063", "144"]
|
| 17 |
+
|
| 18 |
+
music_dir = "data/finedance/div_by_time/music_npy_120"
|
| 19 |
+
count = 10
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test(opt):
|
| 23 |
+
# split = get_train_test_dict(opt.datasplit)
|
| 24 |
+
train_list, test_list, ignore_list = get_train_test_list(opt.datasplit)
|
| 25 |
+
for file in os.listdir(music_dir):
|
| 26 |
+
if file[:3] in ignore_list:
|
| 27 |
+
continue
|
| 28 |
+
if not file[:3] in test_list:
|
| 29 |
+
continue
|
| 30 |
+
|
| 31 |
+
file_name = file[:-4]
|
| 32 |
+
music_fea = np.load(os.path.join(music_dir, file))
|
| 33 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 34 |
+
music_fea = torch.from_numpy(music_fea).float().to(device).unsqueeze(0)
|
| 35 |
+
music_fea = music_fea.repeat(count, 1, 1)
|
| 36 |
+
all_filenames = [file_name]*count
|
| 37 |
+
|
| 38 |
+
# directory for optionally saving the dances for eval
|
| 39 |
+
fk_out = None
|
| 40 |
+
if opt.save_motions:
|
| 41 |
+
fk_out = opt.motion_save_dir
|
| 42 |
+
|
| 43 |
+
model = EDGE(opt, opt.feature_type, opt.checkpoint)
|
| 44 |
+
model.eval()
|
| 45 |
+
|
| 46 |
+
data_tuple = None, music_fea, all_filenames
|
| 47 |
+
model.render_sample(
|
| 48 |
+
data_tuple, "test", opt.render_dir, render_count=10, mode='normal', fk_out=fk_out, render=not opt.no_render
|
| 49 |
+
)
|
| 50 |
+
print("Done")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
+
opt = FineDance_parse_test_opt()
|
| 55 |
+
test(opt)
|
| 56 |
+
|
| 57 |
+
# python test.py --save_motions
|
generate_dance.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
End-to-end dance generation from a music file.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python generate_dance.py /path/to/music.mp3
|
| 6 |
+
python generate_dance.py /path/to/music.mp3 --output my_dance.mp4
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import glob
|
| 11 |
+
import os
|
| 12 |
+
import subprocess
|
| 13 |
+
import sys
|
| 14 |
+
from functools import cmp_to_key
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from tempfile import TemporaryDirectory
|
| 17 |
+
|
| 18 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
| 19 |
+
|
| 20 |
+
import librosa
|
| 21 |
+
import librosa as lr
|
| 22 |
+
import numpy as np
|
| 23 |
+
import soundfile as sf
|
| 24 |
+
import torch
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
from train_seq import EDGE
|
| 28 |
+
from render import MovieMaker, motion_data_load_process
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# --- Audio utilities (from test.py) ---
|
| 32 |
+
|
| 33 |
+
def slice_audio(audio_file, stride, length, out_dir):
|
| 34 |
+
audio, sr = lr.load(audio_file, sr=None)
|
| 35 |
+
file_name = os.path.splitext(os.path.basename(audio_file))[0]
|
| 36 |
+
start_idx = 0
|
| 37 |
+
idx = 0
|
| 38 |
+
window = int(length * sr)
|
| 39 |
+
stride_step = int(stride * sr)
|
| 40 |
+
while start_idx <= len(audio) - window:
|
| 41 |
+
audio_slice = audio[start_idx : start_idx + window]
|
| 42 |
+
sf.write(f"{out_dir}/{file_name}_slice{idx}.wav", audio_slice, sr)
|
| 43 |
+
start_idx += stride_step
|
| 44 |
+
idx += 1
|
| 45 |
+
return idx
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def extract_features(fpath, full_seq_len=120):
|
| 49 |
+
FPS = 30
|
| 50 |
+
HOP_LENGTH = 512
|
| 51 |
+
SR = FPS * HOP_LENGTH
|
| 52 |
+
|
| 53 |
+
data, _ = librosa.load(fpath, sr=SR)
|
| 54 |
+
envelope = librosa.onset.onset_strength(y=data, sr=SR)
|
| 55 |
+
mfcc = librosa.feature.mfcc(y=data, sr=SR, n_mfcc=20).T
|
| 56 |
+
chroma = librosa.feature.chroma_cens(y=data, sr=SR, hop_length=HOP_LENGTH, n_chroma=12).T
|
| 57 |
+
|
| 58 |
+
peak_idxs = librosa.onset.onset_detect(
|
| 59 |
+
onset_envelope=envelope.flatten(), sr=SR, hop_length=HOP_LENGTH
|
| 60 |
+
)
|
| 61 |
+
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
|
| 62 |
+
peak_onehot[peak_idxs] = 1.0
|
| 63 |
+
|
| 64 |
+
start_bpm = lr.beat.tempo(y=lr.load(fpath)[0])[0]
|
| 65 |
+
tempo, beat_idxs = librosa.beat.beat_track(
|
| 66 |
+
onset_envelope=envelope, sr=SR, hop_length=HOP_LENGTH,
|
| 67 |
+
start_bpm=start_bpm, tightness=100,
|
| 68 |
+
)
|
| 69 |
+
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
|
| 70 |
+
beat_onehot[beat_idxs] = 1.0
|
| 71 |
+
|
| 72 |
+
audio_feature = np.concatenate(
|
| 73 |
+
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
|
| 74 |
+
axis=-1,
|
| 75 |
+
)
|
| 76 |
+
audio_feature = audio_feature[:4 * FPS]
|
| 77 |
+
return audio_feature
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])
|
| 81 |
+
|
| 82 |
+
def stringintcmp_(a, b):
|
| 83 |
+
aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
|
| 84 |
+
ka, kb = key_func(a), key_func(b)
|
| 85 |
+
if aa < bb:
|
| 86 |
+
return -1
|
| 87 |
+
if aa > bb:
|
| 88 |
+
return 1
|
| 89 |
+
if ka < kb:
|
| 90 |
+
return -1
|
| 91 |
+
if ka > kb:
|
| 92 |
+
return 1
|
| 93 |
+
return 0
|
| 94 |
+
|
| 95 |
+
stringintkey = cmp_to_key(stringintcmp_)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# --- Model loading ---
|
| 99 |
+
|
| 100 |
+
class _Opt:
|
| 101 |
+
"""Minimal config namespace for EDGE model."""
|
| 102 |
+
feature_type = "baseline"
|
| 103 |
+
full_seq_len = 120
|
| 104 |
+
windows = 10
|
| 105 |
+
nfeats = 319
|
| 106 |
+
do_normalize = False
|
| 107 |
+
datasplit = "cross_genre"
|
| 108 |
+
project = "experiments/finedance_seq_120_genre/train"
|
| 109 |
+
exp_name = "finedance_seq_120_genre"
|
| 110 |
+
render_dir = "tmp_renders"
|
| 111 |
+
batch_size = 64
|
| 112 |
+
epochs = 1
|
| 113 |
+
save_interval = 10
|
| 114 |
+
ema_interval = 1
|
| 115 |
+
checkpoint = ""
|
| 116 |
+
wandb_pj_name = "finedance_seq"
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def load_model(checkpoint_path="assets/checkpoints/train-2000.pt"):
|
| 120 |
+
"""Load the EDGE model once. Returns (model, opt)."""
|
| 121 |
+
opt = _Opt()
|
| 122 |
+
model = EDGE(opt, opt.feature_type, checkpoint_path)
|
| 123 |
+
model.eval()
|
| 124 |
+
return model, opt
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _setup_render_args():
|
| 128 |
+
"""Inject render.py global args for MovieMaker."""
|
| 129 |
+
import render as render_module
|
| 130 |
+
render_module.args = argparse.Namespace(
|
| 131 |
+
mode="smplx", fps=30, gpu="0", modir="", save_path=None
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# --- Main pipeline ---
|
| 136 |
+
|
| 137 |
+
def generate(music_path, output_path, model=None, visualizer=None, log_fn=print):
|
| 138 |
+
"""
|
| 139 |
+
Generate a dance video from a music file.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
music_path: Path to input audio (mp3, wav, etc.)
|
| 143 |
+
output_path: Where to save the output mp4
|
| 144 |
+
model: Pre-loaded (model, opt) tuple. If None, loads fresh.
|
| 145 |
+
visualizer: Pre-built MovieMaker instance. If None, creates one.
|
| 146 |
+
log_fn: Callable for status messages (default: print)
|
| 147 |
+
"""
|
| 148 |
+
import shutil
|
| 149 |
+
|
| 150 |
+
music_path = os.path.abspath(music_path)
|
| 151 |
+
songname = os.path.splitext(os.path.basename(music_path))[0]
|
| 152 |
+
|
| 153 |
+
# Step 1: Convert to WAV if needed
|
| 154 |
+
log_fn(f"[1/5] Preparing audio: {os.path.basename(music_path)}")
|
| 155 |
+
temp_root = TemporaryDirectory()
|
| 156 |
+
wav_dir = os.path.join(temp_root.name, "wav")
|
| 157 |
+
os.makedirs(wav_dir)
|
| 158 |
+
wav_path = os.path.join(wav_dir, songname + ".wav")
|
| 159 |
+
|
| 160 |
+
if music_path.lower().endswith(".wav"):
|
| 161 |
+
shutil.copy2(music_path, wav_path)
|
| 162 |
+
else:
|
| 163 |
+
subprocess.run(
|
| 164 |
+
["ffmpeg", "-i", music_path, wav_path, "-y"],
|
| 165 |
+
capture_output=True, check=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
# Step 2: Slice and extract features
|
| 169 |
+
log_fn("[2/5] Extracting audio features...")
|
| 170 |
+
slice_dir = os.path.join(temp_root.name, "slices")
|
| 171 |
+
os.makedirs(slice_dir)
|
| 172 |
+
stride = 60 / 30 # 2 seconds
|
| 173 |
+
full_seq_len = 120
|
| 174 |
+
slice_audio(wav_path, stride, full_seq_len / 30, slice_dir)
|
| 175 |
+
|
| 176 |
+
file_list = sorted(glob.glob(f"{slice_dir}/*.wav"), key=stringintkey)
|
| 177 |
+
out_length = 30 # seconds
|
| 178 |
+
sample_size = int(out_length / stride) - 1
|
| 179 |
+
|
| 180 |
+
cond_list = []
|
| 181 |
+
for file in tqdm(file_list[:sample_size]):
|
| 182 |
+
reps = extract_features(file)[:full_seq_len]
|
| 183 |
+
cond_list.append(reps)
|
| 184 |
+
cond = torch.from_numpy(np.array(cond_list))
|
| 185 |
+
filenames = file_list[:sample_size]
|
| 186 |
+
|
| 187 |
+
# Step 3: Generate motion
|
| 188 |
+
log_fn("[3/5] Generating dance motion...")
|
| 189 |
+
if model is None:
|
| 190 |
+
edge_model, opt = load_model()
|
| 191 |
+
else:
|
| 192 |
+
edge_model, opt = model
|
| 193 |
+
|
| 194 |
+
motion_dir = os.path.join(temp_root.name, "motions")
|
| 195 |
+
os.makedirs(motion_dir)
|
| 196 |
+
|
| 197 |
+
data_tuple = (None, cond, filenames)
|
| 198 |
+
edge_model.render_sample(
|
| 199 |
+
data_tuple, "gen", temp_root.name, render_count=-1,
|
| 200 |
+
fk_out=motion_dir, mode="long", render=False,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Step 4: Render video
|
| 204 |
+
log_fn("[4/5] Rendering video...")
|
| 205 |
+
motion_file = glob.glob(os.path.join(motion_dir, "*.pkl"))[0]
|
| 206 |
+
modata = motion_data_load_process(motion_file)
|
| 207 |
+
|
| 208 |
+
video_dir = os.path.join(temp_root.name, "video")
|
| 209 |
+
os.makedirs(video_dir)
|
| 210 |
+
|
| 211 |
+
_setup_render_args()
|
| 212 |
+
if visualizer is None:
|
| 213 |
+
visualizer = MovieMaker(save_path=video_dir)
|
| 214 |
+
else:
|
| 215 |
+
visualizer.save_path = video_dir
|
| 216 |
+
visualizer.run(modata, tab=songname, music_file=wav_path)
|
| 217 |
+
|
| 218 |
+
# Step 5: Copy final output
|
| 219 |
+
log_fn("[5/5] Saving output...")
|
| 220 |
+
rendered_file = os.path.join(video_dir, songname + "z.mp4")
|
| 221 |
+
output_path = os.path.abspath(output_path)
|
| 222 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 223 |
+
shutil.move(rendered_file, output_path)
|
| 224 |
+
|
| 225 |
+
temp_root.cleanup()
|
| 226 |
+
log_fn(f"Done! Output saved to: {output_path}")
|
| 227 |
+
return output_path
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
parser = argparse.ArgumentParser(description="Generate a dance video from a music file.")
|
| 232 |
+
parser.add_argument("music", type=str, help="Path to the input music file (mp3, wav, etc.)")
|
| 233 |
+
parser.add_argument("--output", type=str, default=None, help="Output video path (default: output/<songname>_dance.mp4)")
|
| 234 |
+
args = parser.parse_args()
|
| 235 |
+
|
| 236 |
+
if args.output is None:
|
| 237 |
+
songname = os.path.splitext(os.path.basename(args.music))[0]
|
| 238 |
+
args.output = os.path.join("output", f"{songname}_dance.mp4")
|
| 239 |
+
|
| 240 |
+
generate(args.music, args.output)
|
model/adan.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.optim import Optimizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def exists(val):
|
| 8 |
+
return val is not None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Adan(Optimizer):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
params,
|
| 15 |
+
lr=1e-3,
|
| 16 |
+
betas=(0.02, 0.08, 0.01),
|
| 17 |
+
eps=1e-8,
|
| 18 |
+
weight_decay=0,
|
| 19 |
+
restart_cond: callable = None,
|
| 20 |
+
):
|
| 21 |
+
assert len(betas) == 3
|
| 22 |
+
|
| 23 |
+
defaults = dict(
|
| 24 |
+
lr=lr,
|
| 25 |
+
betas=betas,
|
| 26 |
+
eps=eps,
|
| 27 |
+
weight_decay=weight_decay,
|
| 28 |
+
restart_cond=restart_cond,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
super().__init__(params, defaults)
|
| 32 |
+
|
| 33 |
+
def step(self, closure=None):
|
| 34 |
+
loss = None
|
| 35 |
+
|
| 36 |
+
if exists(closure):
|
| 37 |
+
loss = closure()
|
| 38 |
+
|
| 39 |
+
for group in self.param_groups:
|
| 40 |
+
|
| 41 |
+
lr = group["lr"]
|
| 42 |
+
beta1, beta2, beta3 = group["betas"]
|
| 43 |
+
weight_decay = group["weight_decay"]
|
| 44 |
+
eps = group["eps"]
|
| 45 |
+
restart_cond = group["restart_cond"]
|
| 46 |
+
|
| 47 |
+
for p in group["params"]:
|
| 48 |
+
if not exists(p.grad):
|
| 49 |
+
continue
|
| 50 |
+
|
| 51 |
+
data, grad = p.data, p.grad.data
|
| 52 |
+
assert not grad.is_sparse
|
| 53 |
+
|
| 54 |
+
state = self.state[p]
|
| 55 |
+
|
| 56 |
+
if len(state) == 0:
|
| 57 |
+
state["step"] = 0
|
| 58 |
+
state["prev_grad"] = torch.zeros_like(grad)
|
| 59 |
+
state["m"] = torch.zeros_like(grad)
|
| 60 |
+
state["v"] = torch.zeros_like(grad)
|
| 61 |
+
state["n"] = torch.zeros_like(grad)
|
| 62 |
+
|
| 63 |
+
step, m, v, n, prev_grad = (
|
| 64 |
+
state["step"],
|
| 65 |
+
state["m"],
|
| 66 |
+
state["v"],
|
| 67 |
+
state["n"],
|
| 68 |
+
state["prev_grad"],
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
if step > 0:
|
| 72 |
+
prev_grad = state["prev_grad"]
|
| 73 |
+
|
| 74 |
+
# main algorithm
|
| 75 |
+
|
| 76 |
+
m.mul_(1 - beta1).add_(grad, alpha=beta1)
|
| 77 |
+
|
| 78 |
+
grad_diff = grad - prev_grad
|
| 79 |
+
|
| 80 |
+
v.mul_(1 - beta2).add_(grad_diff, alpha=beta2)
|
| 81 |
+
|
| 82 |
+
next_n = (grad + (1 - beta2) * grad_diff) ** 2
|
| 83 |
+
|
| 84 |
+
n.mul_(1 - beta3).add_(next_n, alpha=beta3)
|
| 85 |
+
|
| 86 |
+
# bias correction terms
|
| 87 |
+
|
| 88 |
+
step += 1
|
| 89 |
+
|
| 90 |
+
correct_m, correct_v, correct_n = map(
|
| 91 |
+
lambda n: 1 / (1 - (1 - n) ** step), (beta1, beta2, beta3)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# gradient step
|
| 95 |
+
|
| 96 |
+
def grad_step_(data, m, v, n):
|
| 97 |
+
weighted_step_size = lr / (n * correct_n).sqrt().add_(eps)
|
| 98 |
+
|
| 99 |
+
denom = 1 + weight_decay * lr
|
| 100 |
+
|
| 101 |
+
data.addcmul_(
|
| 102 |
+
weighted_step_size,
|
| 103 |
+
(m * correct_m + (1 - beta2) * v * correct_v),
|
| 104 |
+
value=-1.0,
|
| 105 |
+
).div_(denom)
|
| 106 |
+
|
| 107 |
+
grad_step_(data, m, v, n)
|
| 108 |
+
|
| 109 |
+
# restart condition
|
| 110 |
+
|
| 111 |
+
if exists(restart_cond) and restart_cond(state):
|
| 112 |
+
m.data.copy_(grad)
|
| 113 |
+
v.zero_()
|
| 114 |
+
n.data.copy_(grad ** 2)
|
| 115 |
+
|
| 116 |
+
grad_step_(data, m, v, n)
|
| 117 |
+
|
| 118 |
+
# set new incremented step
|
| 119 |
+
|
| 120 |
+
prev_grad.copy_(grad)
|
| 121 |
+
state["step"] = step
|
| 122 |
+
|
| 123 |
+
return loss
|
model/diffusion.py
ADDED
|
@@ -0,0 +1,741 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from einops import reduce
|
| 12 |
+
from p_tqdm import p_map
|
| 13 |
+
from pytorch3d.transforms import (axis_angle_to_quaternion,
|
| 14 |
+
quaternion_to_axis_angle)
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
from dataset.quaternion import ax_from_6v, quat_slerp
|
| 18 |
+
from vis import skeleton_render
|
| 19 |
+
from vis import SMPLX_Skeleton
|
| 20 |
+
from dataset.preprocess import My_Normalizer as Normalizer
|
| 21 |
+
|
| 22 |
+
from .utils import extract, make_beta_schedule
|
| 23 |
+
|
| 24 |
+
def identity(t, *args, **kwargs):
|
| 25 |
+
return t
|
| 26 |
+
|
| 27 |
+
class EMA:
|
| 28 |
+
def __init__(self, beta):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.beta = beta
|
| 31 |
+
|
| 32 |
+
def update_model_average(self, ma_model, current_model):
|
| 33 |
+
for current_params, ma_params in zip(
|
| 34 |
+
current_model.parameters(), ma_model.parameters()
|
| 35 |
+
):
|
| 36 |
+
old_weight, up_weight = ma_params.data, current_params.data
|
| 37 |
+
ma_params.data = self.update_average(old_weight, up_weight)
|
| 38 |
+
|
| 39 |
+
def update_average(self, old, new):
|
| 40 |
+
if old is None:
|
| 41 |
+
return new
|
| 42 |
+
return old * self.beta + (1 - self.beta) * new
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class GaussianDiffusion(nn.Module):
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
model,
|
| 49 |
+
opt,
|
| 50 |
+
horizon,
|
| 51 |
+
repr_dim,
|
| 52 |
+
smplx_model,
|
| 53 |
+
n_timestep=1000,
|
| 54 |
+
schedule="linear",
|
| 55 |
+
loss_type="l1",
|
| 56 |
+
clip_denoised=True,
|
| 57 |
+
predict_epsilon=True,
|
| 58 |
+
guidance_weight=3,
|
| 59 |
+
use_p2=False,
|
| 60 |
+
cond_drop_prob=0.2,
|
| 61 |
+
do_normalize=False,
|
| 62 |
+
):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.horizon = horizon
|
| 65 |
+
self.transition_dim = repr_dim
|
| 66 |
+
self.model = model
|
| 67 |
+
self.ema = EMA(0.9999)
|
| 68 |
+
self.master_model = copy.deepcopy(self.model)
|
| 69 |
+
self.normalizer = None
|
| 70 |
+
self.do_normalize = do_normalize
|
| 71 |
+
self.opt = opt
|
| 72 |
+
|
| 73 |
+
self.cond_drop_prob = cond_drop_prob
|
| 74 |
+
|
| 75 |
+
# make a SMPL instance for FK module
|
| 76 |
+
self.smplx_fk = smplx_model
|
| 77 |
+
|
| 78 |
+
betas = torch.Tensor(
|
| 79 |
+
make_beta_schedule(schedule=schedule, n_timestep=n_timestep)
|
| 80 |
+
)
|
| 81 |
+
alphas = 1.0 - betas
|
| 82 |
+
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
| 83 |
+
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
|
| 84 |
+
|
| 85 |
+
self.n_timestep = int(n_timestep)
|
| 86 |
+
self.clip_denoised = clip_denoised
|
| 87 |
+
self.predict_epsilon = predict_epsilon
|
| 88 |
+
|
| 89 |
+
self.register_buffer("betas", betas)
|
| 90 |
+
self.register_buffer("alphas_cumprod", alphas_cumprod)
|
| 91 |
+
self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
|
| 92 |
+
|
| 93 |
+
self.guidance_weight = guidance_weight
|
| 94 |
+
|
| 95 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 96 |
+
self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
|
| 97 |
+
self.register_buffer(
|
| 98 |
+
"sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
|
| 99 |
+
)
|
| 100 |
+
self.register_buffer(
|
| 101 |
+
"log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
|
| 102 |
+
)
|
| 103 |
+
self.register_buffer(
|
| 104 |
+
"sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
|
| 105 |
+
)
|
| 106 |
+
self.register_buffer(
|
| 107 |
+
"sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 111 |
+
posterior_variance = (
|
| 112 |
+
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
| 113 |
+
)
|
| 114 |
+
self.register_buffer("posterior_variance", posterior_variance)
|
| 115 |
+
|
| 116 |
+
## log calculation clipped because the posterior variance
|
| 117 |
+
## is 0 at the beginning of the diffusion chain
|
| 118 |
+
self.register_buffer(
|
| 119 |
+
"posterior_log_variance_clipped",
|
| 120 |
+
torch.log(torch.clamp(posterior_variance, min=1e-20)),
|
| 121 |
+
)
|
| 122 |
+
self.register_buffer(
|
| 123 |
+
"posterior_mean_coef1",
|
| 124 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
|
| 125 |
+
)
|
| 126 |
+
self.register_buffer(
|
| 127 |
+
"posterior_mean_coef2",
|
| 128 |
+
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# p2 weighting
|
| 132 |
+
self.p2_loss_weight_k = 1
|
| 133 |
+
self.p2_loss_weight_gamma = 0.5 if use_p2 else 0
|
| 134 |
+
self.register_buffer(
|
| 135 |
+
"p2_loss_weight",
|
| 136 |
+
(self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
|
| 137 |
+
** -self.p2_loss_weight_gamma,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
## get loss coefficients and initialize objective
|
| 141 |
+
self.loss_fn = F.mse_loss if loss_type == "l2" else F.l1_loss
|
| 142 |
+
|
| 143 |
+
# ------------------------------------------ sampling ------------------------------------------#
|
| 144 |
+
|
| 145 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
| 146 |
+
"""
|
| 147 |
+
if self.predict_epsilon, model output is (scaled) noise;
|
| 148 |
+
otherwise, model predicts x0 directly
|
| 149 |
+
"""
|
| 150 |
+
if self.predict_epsilon:
|
| 151 |
+
return (
|
| 152 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
| 153 |
+
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
return noise
|
| 157 |
+
|
| 158 |
+
def predict_noise_from_start(self, x_t, t, x0):
|
| 159 |
+
return (
|
| 160 |
+
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
|
| 161 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
def model_predictions(self, x, cond, t, weight=None, clip_x_start = False):
|
| 165 |
+
weight = weight if weight is not None else self.guidance_weight
|
| 166 |
+
model_output = self.model.guided_forward(x, cond, t, weight)
|
| 167 |
+
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
|
| 168 |
+
|
| 169 |
+
x_start = model_output
|
| 170 |
+
x_start = maybe_clip(x_start)
|
| 171 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
| 172 |
+
|
| 173 |
+
return pred_noise, x_start
|
| 174 |
+
|
| 175 |
+
def q_posterior(self, x_start, x_t, t):
|
| 176 |
+
posterior_mean = (
|
| 177 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
| 178 |
+
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 179 |
+
)
|
| 180 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
| 181 |
+
posterior_log_variance_clipped = extract(
|
| 182 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
| 183 |
+
)
|
| 184 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 185 |
+
|
| 186 |
+
def p_mean_variance(self, x, cond, t):
|
| 187 |
+
# guidance clipping
|
| 188 |
+
if t[0] > 1.0 * self.n_timestep:
|
| 189 |
+
weight = min(self.guidance_weight, 0)
|
| 190 |
+
elif t[0] < 0.1 * self.n_timestep:
|
| 191 |
+
weight = min(self.guidance_weight, 1)
|
| 192 |
+
else:
|
| 193 |
+
weight = self.guidance_weight
|
| 194 |
+
|
| 195 |
+
x_recon = self.predict_start_from_noise(
|
| 196 |
+
x, t=t, noise=self.model.guided_forward(x, cond, t, weight)
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
if self.clip_denoised:
|
| 200 |
+
x_recon.clamp_(-1.0, 1.0)
|
| 201 |
+
else:
|
| 202 |
+
assert RuntimeError()
|
| 203 |
+
|
| 204 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
|
| 205 |
+
x_start=x_recon, x_t=x, t=t
|
| 206 |
+
)
|
| 207 |
+
return model_mean, posterior_variance, posterior_log_variance, x_recon
|
| 208 |
+
|
| 209 |
+
@torch.no_grad()
|
| 210 |
+
def p_sample(self, x, cond, t):
|
| 211 |
+
b, *_, device = *x.shape, x.device
|
| 212 |
+
model_mean, _, model_log_variance, x_start = self.p_mean_variance(
|
| 213 |
+
x=x, cond=cond, t=t
|
| 214 |
+
)
|
| 215 |
+
noise = torch.randn_like(model_mean)
|
| 216 |
+
# no noise when t == 0
|
| 217 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(
|
| 218 |
+
b, *((1,) * (len(noise.shape) - 1))
|
| 219 |
+
)
|
| 220 |
+
x_out = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
| 221 |
+
return x_out, x_start
|
| 222 |
+
|
| 223 |
+
@torch.no_grad()
|
| 224 |
+
def p_sample_loop(
|
| 225 |
+
self,
|
| 226 |
+
shape,
|
| 227 |
+
cond,
|
| 228 |
+
noise=None,
|
| 229 |
+
constraint=None,
|
| 230 |
+
return_diffusion=False,
|
| 231 |
+
start_point=None,
|
| 232 |
+
):
|
| 233 |
+
device = self.betas.device
|
| 234 |
+
|
| 235 |
+
# default to diffusion over whole timescale
|
| 236 |
+
start_point = self.n_timestep if start_point is None else start_point
|
| 237 |
+
batch_size = shape[0]
|
| 238 |
+
x = torch.randn(shape, device=device) if noise is None else noise.to(device)
|
| 239 |
+
cond = cond.to(device)
|
| 240 |
+
|
| 241 |
+
if return_diffusion:
|
| 242 |
+
diffusion = [x]
|
| 243 |
+
|
| 244 |
+
for i in tqdm(reversed(range(0, start_point))):
|
| 245 |
+
# fill with i
|
| 246 |
+
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
|
| 247 |
+
x, _ = self.p_sample(x, cond, timesteps)
|
| 248 |
+
|
| 249 |
+
if return_diffusion:
|
| 250 |
+
diffusion.append(x)
|
| 251 |
+
|
| 252 |
+
if return_diffusion:
|
| 253 |
+
return x, diffusion
|
| 254 |
+
else:
|
| 255 |
+
return x
|
| 256 |
+
|
| 257 |
+
@torch.no_grad()
|
| 258 |
+
def ddim_sample(self, shape, cond, **kwargs):
|
| 259 |
+
batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 1
|
| 260 |
+
|
| 261 |
+
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
|
| 262 |
+
times = list(reversed(times.int().tolist()))
|
| 263 |
+
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
|
| 264 |
+
|
| 265 |
+
x = torch.randn(shape, device = device)
|
| 266 |
+
cond = cond.to(device)
|
| 267 |
+
|
| 268 |
+
x_start = None
|
| 269 |
+
|
| 270 |
+
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
| 271 |
+
time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
|
| 272 |
+
pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, clip_x_start = self.clip_denoised)
|
| 273 |
+
|
| 274 |
+
if time_next < 0:
|
| 275 |
+
x = x_start
|
| 276 |
+
continue
|
| 277 |
+
|
| 278 |
+
alpha = self.alphas_cumprod[time]
|
| 279 |
+
alpha_next = self.alphas_cumprod[time_next]
|
| 280 |
+
|
| 281 |
+
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
| 282 |
+
c = (1 - alpha_next - sigma ** 2).sqrt()
|
| 283 |
+
|
| 284 |
+
noise = torch.randn_like(x)
|
| 285 |
+
|
| 286 |
+
x = x_start * alpha_next.sqrt() + \
|
| 287 |
+
c * pred_noise + \
|
| 288 |
+
sigma * noise
|
| 289 |
+
return x
|
| 290 |
+
|
| 291 |
+
@torch.no_grad()
|
| 292 |
+
def long_ddim_sample(self, shape, cond, **kwargs):
|
| 293 |
+
batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.n_timestep, 50, 1
|
| 294 |
+
|
| 295 |
+
if batch == 1:
|
| 296 |
+
return self.ddim_sample(shape, cond)
|
| 297 |
+
|
| 298 |
+
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
|
| 299 |
+
times = list(reversed(times.int().tolist()))
|
| 300 |
+
weights = np.clip(np.linspace(0, self.guidance_weight * 2, sampling_timesteps), None, self.guidance_weight)
|
| 301 |
+
time_pairs = list(zip(times[:-1], times[1:], weights)) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
|
| 302 |
+
|
| 303 |
+
x = torch.randn(shape, device = device)
|
| 304 |
+
cond = cond.to(device)
|
| 305 |
+
|
| 306 |
+
assert batch > 1
|
| 307 |
+
assert x.shape[1] % 2 == 0
|
| 308 |
+
half = x.shape[1] // 2
|
| 309 |
+
|
| 310 |
+
x_start = None
|
| 311 |
+
|
| 312 |
+
for time, time_next, weight in tqdm(time_pairs, desc = 'sampling loop time step'):
|
| 313 |
+
time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
|
| 314 |
+
pred_noise, x_start, *_ = self.model_predictions(x, cond, time_cond, weight=weight, clip_x_start = self.clip_denoised)
|
| 315 |
+
|
| 316 |
+
if time_next < 0:
|
| 317 |
+
x = x_start
|
| 318 |
+
continue
|
| 319 |
+
|
| 320 |
+
alpha = self.alphas_cumprod[time]
|
| 321 |
+
alpha_next = self.alphas_cumprod[time_next]
|
| 322 |
+
|
| 323 |
+
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
| 324 |
+
c = (1 - alpha_next - sigma ** 2).sqrt()
|
| 325 |
+
|
| 326 |
+
noise = torch.randn_like(x)
|
| 327 |
+
|
| 328 |
+
x = x_start * alpha_next.sqrt() + \
|
| 329 |
+
c * pred_noise + \
|
| 330 |
+
sigma * noise
|
| 331 |
+
|
| 332 |
+
if time > 0:
|
| 333 |
+
# the first half of each sequence is the second half of the previous one
|
| 334 |
+
x[1:, :half] = x[:-1, half:]
|
| 335 |
+
return x
|
| 336 |
+
|
| 337 |
+
@torch.no_grad()
|
| 338 |
+
def inpaint_loop(
|
| 339 |
+
self,
|
| 340 |
+
shape,
|
| 341 |
+
cond,
|
| 342 |
+
noise=None,
|
| 343 |
+
constraint=None,
|
| 344 |
+
return_diffusion=False,
|
| 345 |
+
start_point=None,
|
| 346 |
+
):
|
| 347 |
+
device = self.betas.device
|
| 348 |
+
|
| 349 |
+
batch_size = shape[0]
|
| 350 |
+
x = torch.randn(shape, device=device) if noise is None else noise.to(device)
|
| 351 |
+
cond = cond.to(device)
|
| 352 |
+
if return_diffusion:
|
| 353 |
+
diffusion = [x]
|
| 354 |
+
|
| 355 |
+
mask = constraint["mask"].to(device) # batch x horizon x channels
|
| 356 |
+
value = constraint["value"].to(device) # batch x horizon x channels
|
| 357 |
+
|
| 358 |
+
start_point = self.n_timestep if start_point is None else start_point
|
| 359 |
+
for i in tqdm(reversed(range(0, start_point))):
|
| 360 |
+
# fill with i
|
| 361 |
+
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
|
| 362 |
+
|
| 363 |
+
# sample x from step i to step i-1
|
| 364 |
+
x, _ = self.p_sample(x, cond, timesteps)
|
| 365 |
+
# enforce constraint between each denoising step
|
| 366 |
+
value_ = self.q_sample(value, timesteps - 1) if (i > 0) else x
|
| 367 |
+
x = value_ * mask + (1.0 - mask) * x
|
| 368 |
+
|
| 369 |
+
if return_diffusion:
|
| 370 |
+
diffusion.append(x)
|
| 371 |
+
|
| 372 |
+
if return_diffusion:
|
| 373 |
+
return x, diffusion
|
| 374 |
+
else:
|
| 375 |
+
return x
|
| 376 |
+
|
| 377 |
+
@torch.no_grad()
|
| 378 |
+
def long_inpaint_loop(
|
| 379 |
+
self,
|
| 380 |
+
shape,
|
| 381 |
+
cond,
|
| 382 |
+
noise=None,
|
| 383 |
+
constraint=None,
|
| 384 |
+
return_diffusion=False,
|
| 385 |
+
start_point=None,
|
| 386 |
+
):
|
| 387 |
+
device = self.betas.device
|
| 388 |
+
|
| 389 |
+
batch_size = shape[0]
|
| 390 |
+
x = torch.randn(shape, device=device) if noise is None else noise.to(device)
|
| 391 |
+
cond = cond.to(device)
|
| 392 |
+
if return_diffusion:
|
| 393 |
+
diffusion = [x]
|
| 394 |
+
|
| 395 |
+
assert x.shape[1] % 2 == 0
|
| 396 |
+
if batch_size == 1:
|
| 397 |
+
# there's no continuation to do, just do normal
|
| 398 |
+
return self.p_sample_loop(
|
| 399 |
+
shape,
|
| 400 |
+
cond,
|
| 401 |
+
noise=noise,
|
| 402 |
+
constraint=constraint,
|
| 403 |
+
return_diffusion=return_diffusion,
|
| 404 |
+
start_point=start_point,
|
| 405 |
+
)
|
| 406 |
+
assert batch_size > 1
|
| 407 |
+
half = x.shape[1] // 2
|
| 408 |
+
|
| 409 |
+
start_point = self.n_timestep if start_point is None else start_point
|
| 410 |
+
for i in tqdm(reversed(range(0, start_point))):
|
| 411 |
+
# fill with i
|
| 412 |
+
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
|
| 413 |
+
|
| 414 |
+
# sample x from step i to step i-1
|
| 415 |
+
x, _ = self.p_sample(x, cond, timesteps)
|
| 416 |
+
# enforce constraint between each denoising step
|
| 417 |
+
if i > 0:
|
| 418 |
+
# the first half of each sequence is the second half of the previous one
|
| 419 |
+
x[1:, :half] = x[:-1, half:]
|
| 420 |
+
|
| 421 |
+
if return_diffusion:
|
| 422 |
+
diffusion.append(x)
|
| 423 |
+
|
| 424 |
+
if return_diffusion:
|
| 425 |
+
return x, diffusion
|
| 426 |
+
else:
|
| 427 |
+
return x
|
| 428 |
+
|
| 429 |
+
@torch.no_grad()
|
| 430 |
+
def conditional_sample(
|
| 431 |
+
self, shape, cond, constraint=None, *args, horizon=None, **kwargs
|
| 432 |
+
):
|
| 433 |
+
"""
|
| 434 |
+
conditions : [ (time, state), ... ]
|
| 435 |
+
"""
|
| 436 |
+
device = self.betas.device
|
| 437 |
+
horizon = horizon or self.horizon
|
| 438 |
+
|
| 439 |
+
return self.p_sample_loop(shape, cond, *args, **kwargs)
|
| 440 |
+
|
| 441 |
+
# ------------------------------------------ training ------------------------------------------#
|
| 442 |
+
|
| 443 |
+
def q_sample(self, x_start, t, noise=None):
|
| 444 |
+
if noise is None:
|
| 445 |
+
noise = torch.randn_like(x_start)
|
| 446 |
+
|
| 447 |
+
sample = (
|
| 448 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 449 |
+
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
return sample
|
| 453 |
+
|
| 454 |
+
def p_losses(self, x_start, cond, t):
|
| 455 |
+
noise = torch.randn_like(x_start)
|
| 456 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # 将x0加噪到xt
|
| 457 |
+
|
| 458 |
+
# reconstruct
|
| 459 |
+
x_recon = self.model(x_noisy, cond, t, cond_drop_prob=self.cond_drop_prob)
|
| 460 |
+
assert noise.shape == x_recon.shape
|
| 461 |
+
|
| 462 |
+
model_out = x_recon
|
| 463 |
+
if self.predict_epsilon:
|
| 464 |
+
target = noise
|
| 465 |
+
else:
|
| 466 |
+
target = x_start
|
| 467 |
+
|
| 468 |
+
# full reconstruction loss
|
| 469 |
+
loss = self.loss_fn(model_out, target, reduction="none") # mse loss
|
| 470 |
+
loss = reduce(loss, "b ... -> b (...)", "mean")
|
| 471 |
+
loss = loss * extract(self.p2_loss_weight, t, loss.shape)
|
| 472 |
+
|
| 473 |
+
# split off contact from the rest
|
| 474 |
+
_, model_out_ = torch.split(
|
| 475 |
+
model_out, (4, model_out.shape[2] - 4), dim=2 # 前4维是foot contact
|
| 476 |
+
)
|
| 477 |
+
_, target_ = torch.split(target, (4, target.shape[2] - 4), dim=2) # b, length, jxc
|
| 478 |
+
|
| 479 |
+
# velocity loss
|
| 480 |
+
target_v = target_[:, 1:] - target_[:, :-1]
|
| 481 |
+
model_out_v = model_out_[:, 1:] - model_out_[:, :-1]
|
| 482 |
+
v_loss = self.loss_fn(model_out_v, target_v, reduction="none")
|
| 483 |
+
v_loss = reduce(v_loss, "b ... -> b (...)", "mean")
|
| 484 |
+
v_loss = v_loss * extract(self.p2_loss_weight, t, v_loss.shape)
|
| 485 |
+
|
| 486 |
+
# FK loss
|
| 487 |
+
b, s, c = model_out.shape
|
| 488 |
+
model_contact, model_out = torch.split(model_out, (4, model_out.shape[2] - 4), dim=2)
|
| 489 |
+
target_contact, target = torch.split(target, (4, target.shape[2] - 4), dim=2) # b, length, jxc
|
| 490 |
+
model_x = model_out[:, :, :3] # root position
|
| 491 |
+
model_q = ax_from_6v(model_out[:, :, 3:].reshape(b, s, -1, 6))
|
| 492 |
+
target_x = target[:, :, :3]
|
| 493 |
+
target_q = ax_from_6v(target[:, :, 3:].reshape(b, s, -1, 6))
|
| 494 |
+
b, s, nums, c_ = model_q.shape
|
| 495 |
+
|
| 496 |
+
if self.opt.nfeats == 139 or self.opt.nfeats==135:
|
| 497 |
+
model_xp = self.smplx_fk.forward(model_q, model_x)
|
| 498 |
+
target_xp = self.smplx_fk.forward(target_q, target_x)
|
| 499 |
+
else:
|
| 500 |
+
model_q = model_q.view(b*s, -1)
|
| 501 |
+
target_q = target_q.view(b*s, -1)
|
| 502 |
+
model_x = model_x.view(-1, 3)
|
| 503 |
+
target_x = target_x.view(-1, 3)
|
| 504 |
+
model_xp = self.smplx_fk.forward(model_q, model_x)
|
| 505 |
+
target_xp = self.smplx_fk.forward(target_q, target_x)
|
| 506 |
+
model_xp = model_xp.view(b, s, -1, 3)
|
| 507 |
+
target_xp = target_xp.view(b, s, -1, 3)
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
fk_loss = self.loss_fn(model_xp, target_xp, reduction="none")
|
| 512 |
+
fk_loss = reduce(fk_loss, "b ... -> b (...)", "mean")
|
| 513 |
+
fk_loss = fk_loss * extract(self.p2_loss_weight, t, fk_loss.shape)
|
| 514 |
+
|
| 515 |
+
# foot skate loss
|
| 516 |
+
foot_idx = [7, 8, 10, 11]
|
| 517 |
+
# find static indices consistent with model's own predictions
|
| 518 |
+
static_idx = model_contact > 0.95 # N x S x 4
|
| 519 |
+
model_feet = model_xp[:, :, foot_idx] # foot positions (N, S, 4, 3)
|
| 520 |
+
model_foot_v = torch.zeros_like(model_feet)
|
| 521 |
+
model_foot_v[:, :-1] = (
|
| 522 |
+
model_feet[:, 1:, :, :] - model_feet[:, :-1, :, :]
|
| 523 |
+
) # (N, S-1, 4, 3)
|
| 524 |
+
model_foot_v[~static_idx] = 0
|
| 525 |
+
foot_loss = self.loss_fn(
|
| 526 |
+
model_foot_v, torch.zeros_like(model_foot_v), reduction="none"
|
| 527 |
+
)
|
| 528 |
+
foot_loss = reduce(foot_loss, "b ... -> b (...)", "mean")
|
| 529 |
+
losses = (
|
| 530 |
+
0.636 * loss.mean(),
|
| 531 |
+
2.964 * v_loss.mean(),
|
| 532 |
+
0.646 * fk_loss.mean(),
|
| 533 |
+
10.942 * foot_loss.mean(),
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
return sum(losses), losses
|
| 537 |
+
|
| 538 |
+
def loss(self, x, cond, t_override=None):
|
| 539 |
+
batch_size = len(x)
|
| 540 |
+
if t_override is None:
|
| 541 |
+
t = torch.randint(0, self.n_timestep, (batch_size,), device=x.device).long()
|
| 542 |
+
else:
|
| 543 |
+
t = torch.full((batch_size,), t_override, device=x.device).long()
|
| 544 |
+
return self.p_losses(x, cond, t)
|
| 545 |
+
|
| 546 |
+
def forward(self, x, cond, t_override=None):
|
| 547 |
+
return self.loss(x, cond, t_override)
|
| 548 |
+
|
| 549 |
+
def partial_denoise(self, x, cond, t):
|
| 550 |
+
x_noisy = self.noise_to_t(x, t)
|
| 551 |
+
return self.p_sample_loop(x.shape, cond, noise=x_noisy, start_point=t)
|
| 552 |
+
|
| 553 |
+
def noise_to_t(self, x, timestep):
|
| 554 |
+
batch_size = len(x)
|
| 555 |
+
t = torch.full((batch_size,), timestep, device=x.device).long()
|
| 556 |
+
return self.q_sample(x, t) if timestep > 0 else x
|
| 557 |
+
|
| 558 |
+
def smplxmodel_fk(self, local_q, root_pos): # input
|
| 559 |
+
b, s, nums, c = local_q.shape
|
| 560 |
+
local_q = local_q.view(b*s, -1)
|
| 561 |
+
full_pose = self.smplx_model(
|
| 562 |
+
betas = torch.zeros([b*s, 10], device=local_q.device, dtype=torch.float32),
|
| 563 |
+
transl = root_pos.view(b*s, -1), # global translation
|
| 564 |
+
global_orient = local_q[:, :3],
|
| 565 |
+
body_pose = local_q[:, 3:66], # 21
|
| 566 |
+
jaw_pose = torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1
|
| 567 |
+
leye_pose = torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1
|
| 568 |
+
reye_pose= torch.zeros([b*s, 3], device=local_q.device, dtype=torch.float32), # 1
|
| 569 |
+
left_hand_pose = local_q[:, 66:111], # 15
|
| 570 |
+
right_hand_pose = local_q[:, 111:], # 15
|
| 571 |
+
expression = torch.zeros([b*s, 10], device=local_q.device, dtype=torch.float32),
|
| 572 |
+
return_verts = False
|
| 573 |
+
)
|
| 574 |
+
full_pose = full_pose.joints.view(b, s, -1, 3) # b, s, 55, 3
|
| 575 |
+
return full_pose
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def render_sample(
|
| 579 |
+
self,
|
| 580 |
+
shape,
|
| 581 |
+
cond,
|
| 582 |
+
normalizer,
|
| 583 |
+
epoch,
|
| 584 |
+
render_out,
|
| 585 |
+
fk_out=None,
|
| 586 |
+
name=None,
|
| 587 |
+
sound=True,
|
| 588 |
+
mode="normal",
|
| 589 |
+
noise=None,
|
| 590 |
+
constraint=None,
|
| 591 |
+
sound_folder="ood_sliced",
|
| 592 |
+
start_point=None,
|
| 593 |
+
render=True,
|
| 594 |
+
# do_normalize=True,
|
| 595 |
+
):
|
| 596 |
+
if isinstance(shape, tuple):
|
| 597 |
+
if mode == "inpaint":
|
| 598 |
+
func_class = self.inpaint_loop
|
| 599 |
+
elif mode == "normal":
|
| 600 |
+
func_class = self.ddim_sample
|
| 601 |
+
elif mode == "long":
|
| 602 |
+
func_class = self.long_ddim_sample
|
| 603 |
+
else:
|
| 604 |
+
assert False, "Unrecognized inference mode"
|
| 605 |
+
samples = (
|
| 606 |
+
func_class(
|
| 607 |
+
shape,
|
| 608 |
+
cond,
|
| 609 |
+
noise=noise,
|
| 610 |
+
constraint=constraint,
|
| 611 |
+
start_point=start_point,
|
| 612 |
+
)
|
| 613 |
+
.detach()
|
| 614 |
+
.cpu()
|
| 615 |
+
)
|
| 616 |
+
else:
|
| 617 |
+
samples = shape
|
| 618 |
+
|
| 619 |
+
if self.do_normalize:
|
| 620 |
+
with torch.no_grad():
|
| 621 |
+
samples = normalizer.unnormalize(samples)
|
| 622 |
+
|
| 623 |
+
if samples.shape[2] == 319 or samples.shape[2] == 151 or samples.shape[2] == 139: # debug if samples.shape[2] == 151:
|
| 624 |
+
sample_contact, samples = torch.split(
|
| 625 |
+
samples, (4, samples.shape[2] - 4), dim=2
|
| 626 |
+
)
|
| 627 |
+
else:
|
| 628 |
+
sample_contact = None
|
| 629 |
+
# do the FK all at once
|
| 630 |
+
b, s, c = samples.shape
|
| 631 |
+
pos = samples[:, :, :3].to(cond.device) # np.zeros((sample.shape[0], 3))
|
| 632 |
+
q = samples[:, :, 3:].reshape(b, s, -1, 6) # debug 24
|
| 633 |
+
# go 6d to ax
|
| 634 |
+
q = ax_from_6v(q).to(cond.device)
|
| 635 |
+
|
| 636 |
+
if self.opt.nfeats == 139 or self.opt.nfeats==135:
|
| 637 |
+
reshape_size = 66
|
| 638 |
+
else:
|
| 639 |
+
reshape_size = 156
|
| 640 |
+
|
| 641 |
+
if mode == "long":
|
| 642 |
+
b, s, c1, c2 = q.shape
|
| 643 |
+
assert s % 2 == 0
|
| 644 |
+
half = s // 2
|
| 645 |
+
if b > 1:
|
| 646 |
+
# if long mode, stitch position using linear interp
|
| 647 |
+
fade_out = torch.ones((1, s, 1)).to(pos.device)
|
| 648 |
+
fade_in = torch.ones((1, s, 1)).to(pos.device)
|
| 649 |
+
fade_out[:, half:, :] = torch.linspace(1, 0, half)[None, :, None].to(
|
| 650 |
+
pos.device
|
| 651 |
+
)
|
| 652 |
+
fade_in[:, :half, :] = torch.linspace(0, 1, half)[None, :, None].to(
|
| 653 |
+
pos.device
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
pos[:-1] *= fade_out
|
| 657 |
+
pos[1:] *= fade_in
|
| 658 |
+
|
| 659 |
+
full_pos = torch.zeros((s + half * (b - 1), 3)).to(pos.device)
|
| 660 |
+
idx = 0
|
| 661 |
+
for pos_slice in pos:
|
| 662 |
+
full_pos[idx : idx + s] += pos_slice
|
| 663 |
+
idx += half
|
| 664 |
+
|
| 665 |
+
# stitch joint angles with slerp
|
| 666 |
+
slerp_weight = torch.linspace(0, 1, half)[None, :, None].to(pos.device)
|
| 667 |
+
|
| 668 |
+
left, right = q[:-1, half:], q[1:, :half]
|
| 669 |
+
# convert to quat
|
| 670 |
+
left, right = (
|
| 671 |
+
axis_angle_to_quaternion(left),
|
| 672 |
+
axis_angle_to_quaternion(right),
|
| 673 |
+
)
|
| 674 |
+
merged = quat_slerp(left, right, slerp_weight) # (b-1) x half x ...
|
| 675 |
+
# convert back
|
| 676 |
+
merged = quaternion_to_axis_angle(merged)
|
| 677 |
+
|
| 678 |
+
full_q = torch.zeros((s + half * (b - 1), c1, c2)).to(pos.device)
|
| 679 |
+
full_q[:half] += q[0, :half]
|
| 680 |
+
idx = half
|
| 681 |
+
for q_slice in merged:
|
| 682 |
+
full_q[idx : idx + half] += q_slice
|
| 683 |
+
idx += half
|
| 684 |
+
full_q[idx : idx + half] += q[-1, half:]
|
| 685 |
+
|
| 686 |
+
# unsqueeze for fk
|
| 687 |
+
full_pos = full_pos.unsqueeze(0)
|
| 688 |
+
full_q = full_q.unsqueeze(0)
|
| 689 |
+
else:
|
| 690 |
+
full_pos = pos
|
| 691 |
+
full_q = q
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
if fk_out is not None:
|
| 695 |
+
outname = f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.pkl' # f'{epoch}_{"_".join(name)}.pkl' #
|
| 696 |
+
Path(fk_out).mkdir(parents=True, exist_ok=True)
|
| 697 |
+
pickle.dump(
|
| 698 |
+
{
|
| 699 |
+
"smpl_poses": full_q.squeeze(0).reshape((-1, reshape_size)).cpu().numpy(), # local rotations
|
| 700 |
+
"smpl_trans": full_pos.squeeze(0).cpu().numpy(), # root translation
|
| 701 |
+
# "full_pose": full_pose[0], # 3d positions
|
| 702 |
+
},
|
| 703 |
+
open(os.path.join(fk_out, outname), "wb"),
|
| 704 |
+
)
|
| 705 |
+
return
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
sample_contact = (
|
| 709 |
+
sample_contact.detach().cpu().numpy()
|
| 710 |
+
if sample_contact is not None
|
| 711 |
+
else None
|
| 712 |
+
)
|
| 713 |
+
def inner(xx):
|
| 714 |
+
num, pose = xx
|
| 715 |
+
filename = name[num] if name is not None else None
|
| 716 |
+
contact = sample_contact[num] if sample_contact is not None else None
|
| 717 |
+
skeleton_render(
|
| 718 |
+
pose,
|
| 719 |
+
epoch=f"e{epoch}_b{num}",
|
| 720 |
+
out=render_out,
|
| 721 |
+
name=filename,
|
| 722 |
+
sound=sound,
|
| 723 |
+
contact=contact,
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
# p_map(inner, enumerate(poses)) # poses: 2, 150, 52, 3
|
| 727 |
+
# print("4")
|
| 728 |
+
if fk_out is not None and mode != "long":
|
| 729 |
+
Path(fk_out).mkdir(parents=True, exist_ok=True)
|
| 730 |
+
# for num, (qq, pos_, filename, pose) in enumerate(zip(q, pos, name, poses)):
|
| 731 |
+
for num, (qq, pos_, filename) in enumerate(zip(q, pos, name)):
|
| 732 |
+
filename = os.path.basename(filename).split(".")[0]
|
| 733 |
+
outname = f"{epoch}_{num}_{filename}.pkl"
|
| 734 |
+
pickle.dump(
|
| 735 |
+
{
|
| 736 |
+
"smpl_poses": qq.reshape((-1, reshape_size)).cpu().numpy(),
|
| 737 |
+
"smpl_trans": pos_.cpu().numpy(),
|
| 738 |
+
# "full_pose": pose,
|
| 739 |
+
},
|
| 740 |
+
open(f"{fk_out}/{outname}", "wb"),
|
| 741 |
+
)
|
model/model.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from einops import rearrange, reduce, repeat
|
| 7 |
+
from einops.layers.torch import Rearrange, Reduce
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
from model.rotary_embedding_torch import RotaryEmbedding
|
| 12 |
+
from model.utils import PositionalEncoding, SinusoidalPosEmb, prob_mask_like
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DenseFiLM(nn.Module):
|
| 16 |
+
"""Feature-wise linear modulation (FiLM) generator."""
|
| 17 |
+
|
| 18 |
+
def __init__(self, embed_channels):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.embed_channels = embed_channels
|
| 21 |
+
self.block = nn.Sequential(
|
| 22 |
+
nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
def forward(self, position):
|
| 26 |
+
pos_encoding = self.block(position)
|
| 27 |
+
pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
|
| 28 |
+
scale_shift = pos_encoding.chunk(2, dim=-1)
|
| 29 |
+
return scale_shift
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def featurewise_affine(x, scale_shift):
|
| 33 |
+
scale, shift = scale_shift
|
| 34 |
+
return (scale + 1) * x + shift
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TransformerEncoderLayer(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
d_model: int,
|
| 41 |
+
nhead: int,
|
| 42 |
+
dim_feedforward: int = 2048,
|
| 43 |
+
dropout: float = 0.1,
|
| 44 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
| 45 |
+
layer_norm_eps: float = 1e-5,
|
| 46 |
+
batch_first: bool = False,
|
| 47 |
+
norm_first: bool = True,
|
| 48 |
+
device=None,
|
| 49 |
+
dtype=None,
|
| 50 |
+
rotary=None,
|
| 51 |
+
) -> None:
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.self_attn = nn.MultiheadAttention(
|
| 54 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
| 55 |
+
)
|
| 56 |
+
# Implementation of Feedforward model
|
| 57 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 58 |
+
self.dropout = nn.Dropout(dropout)
|
| 59 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 60 |
+
|
| 61 |
+
self.norm_first = norm_first
|
| 62 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 63 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 64 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 65 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 66 |
+
self.activation = activation
|
| 67 |
+
|
| 68 |
+
self.rotary = rotary
|
| 69 |
+
self.use_rotary = rotary is not None
|
| 70 |
+
|
| 71 |
+
def forward(
|
| 72 |
+
self,
|
| 73 |
+
src: Tensor,
|
| 74 |
+
src_mask: Optional[Tensor] = None,
|
| 75 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 76 |
+
) -> Tensor:
|
| 77 |
+
x = src
|
| 78 |
+
if self.norm_first:
|
| 79 |
+
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
|
| 80 |
+
x = x + self._ff_block(self.norm2(x))
|
| 81 |
+
else:
|
| 82 |
+
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
|
| 83 |
+
x = self.norm2(x + self._ff_block(x))
|
| 84 |
+
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
# self-attention block
|
| 88 |
+
def _sa_block(
|
| 89 |
+
self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
|
| 90 |
+
) -> Tensor:
|
| 91 |
+
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
| 92 |
+
x = self.self_attn(
|
| 93 |
+
qk,
|
| 94 |
+
qk,
|
| 95 |
+
x,
|
| 96 |
+
attn_mask=attn_mask,
|
| 97 |
+
key_padding_mask=key_padding_mask,
|
| 98 |
+
need_weights=False,
|
| 99 |
+
)[0]
|
| 100 |
+
return self.dropout1(x)
|
| 101 |
+
|
| 102 |
+
# feed forward block
|
| 103 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
| 104 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 105 |
+
return self.dropout2(x)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class FiLMTransformerDecoderLayer(nn.Module):
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
d_model: int,
|
| 112 |
+
nhead: int,
|
| 113 |
+
dim_feedforward=2048,
|
| 114 |
+
dropout=0.1,
|
| 115 |
+
activation=F.relu,
|
| 116 |
+
layer_norm_eps=1e-5,
|
| 117 |
+
batch_first=False,
|
| 118 |
+
norm_first=True,
|
| 119 |
+
device=None,
|
| 120 |
+
dtype=None,
|
| 121 |
+
rotary=None,
|
| 122 |
+
):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.self_attn = nn.MultiheadAttention(
|
| 125 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
| 126 |
+
)
|
| 127 |
+
self.multihead_attn = nn.MultiheadAttention(
|
| 128 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
| 129 |
+
)
|
| 130 |
+
# Feedforward
|
| 131 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
| 132 |
+
self.dropout = nn.Dropout(dropout)
|
| 133 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
| 134 |
+
|
| 135 |
+
self.norm_first = norm_first
|
| 136 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 137 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 138 |
+
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
| 139 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 140 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 141 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 142 |
+
self.activation = activation
|
| 143 |
+
|
| 144 |
+
self.film1 = DenseFiLM(d_model)
|
| 145 |
+
self.film2 = DenseFiLM(d_model)
|
| 146 |
+
self.film3 = DenseFiLM(d_model)
|
| 147 |
+
|
| 148 |
+
self.rotary = rotary
|
| 149 |
+
self.use_rotary = rotary is not None
|
| 150 |
+
|
| 151 |
+
# x, cond, t
|
| 152 |
+
def forward(
|
| 153 |
+
self,
|
| 154 |
+
tgt,
|
| 155 |
+
memory,
|
| 156 |
+
t,
|
| 157 |
+
tgt_mask=None,
|
| 158 |
+
memory_mask=None,
|
| 159 |
+
tgt_key_padding_mask=None,
|
| 160 |
+
memory_key_padding_mask=None,
|
| 161 |
+
):
|
| 162 |
+
x = tgt
|
| 163 |
+
if self.norm_first:
|
| 164 |
+
# self-attention -> film -> residual
|
| 165 |
+
x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
|
| 166 |
+
x = x + featurewise_affine(x_1, self.film1(t))
|
| 167 |
+
# cross-attention -> film -> residual
|
| 168 |
+
x_2 = self._mha_block(
|
| 169 |
+
self.norm2(x), memory, memory_mask, memory_key_padding_mask
|
| 170 |
+
)
|
| 171 |
+
x = x + featurewise_affine(x_2, self.film2(t))
|
| 172 |
+
# feedforward -> film -> residual
|
| 173 |
+
x_3 = self._ff_block(self.norm3(x))
|
| 174 |
+
x = x + featurewise_affine(x_3, self.film3(t))
|
| 175 |
+
else:
|
| 176 |
+
x = self.norm1(
|
| 177 |
+
x
|
| 178 |
+
+ featurewise_affine(
|
| 179 |
+
self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
x = self.norm2(
|
| 183 |
+
x
|
| 184 |
+
+ featurewise_affine(
|
| 185 |
+
self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
|
| 186 |
+
self.film2(t),
|
| 187 |
+
)
|
| 188 |
+
)
|
| 189 |
+
x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
|
| 190 |
+
return x
|
| 191 |
+
|
| 192 |
+
# self-attention block
|
| 193 |
+
# qkv
|
| 194 |
+
def _sa_block(self, x, attn_mask, key_padding_mask):
|
| 195 |
+
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
| 196 |
+
x = self.self_attn(
|
| 197 |
+
qk,
|
| 198 |
+
qk,
|
| 199 |
+
x,
|
| 200 |
+
attn_mask=attn_mask,
|
| 201 |
+
key_padding_mask=key_padding_mask,
|
| 202 |
+
need_weights=False,
|
| 203 |
+
)[0]
|
| 204 |
+
return self.dropout1(x)
|
| 205 |
+
|
| 206 |
+
# multihead attention block
|
| 207 |
+
# qkv
|
| 208 |
+
def _mha_block(self, x, mem, attn_mask, key_padding_mask):
|
| 209 |
+
q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
| 210 |
+
k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
|
| 211 |
+
x = self.multihead_attn(
|
| 212 |
+
q,
|
| 213 |
+
k,
|
| 214 |
+
mem,
|
| 215 |
+
attn_mask=attn_mask,
|
| 216 |
+
key_padding_mask=key_padding_mask,
|
| 217 |
+
need_weights=False,
|
| 218 |
+
)[0]
|
| 219 |
+
return self.dropout2(x)
|
| 220 |
+
|
| 221 |
+
# feed forward block
|
| 222 |
+
def _ff_block(self, x):
|
| 223 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 224 |
+
return self.dropout3(x)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class DecoderLayerStack(nn.Module):
|
| 228 |
+
def __init__(self, stack):
|
| 229 |
+
super().__init__()
|
| 230 |
+
self.stack = stack
|
| 231 |
+
|
| 232 |
+
def forward(self, x, cond, t):
|
| 233 |
+
for layer in self.stack:
|
| 234 |
+
x = layer(x, cond, t)
|
| 235 |
+
return x
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class SeqModel(nn.Module):
|
| 240 |
+
def __init__(self,
|
| 241 |
+
nfeats: int,
|
| 242 |
+
seq_len: int = 150, # 5 seconds, 30 fps
|
| 243 |
+
latent_dim: int = 256,
|
| 244 |
+
ff_size: int = 1024,
|
| 245 |
+
num_layers: int = 4,
|
| 246 |
+
num_heads: int = 4,
|
| 247 |
+
dropout: float = 0.1,
|
| 248 |
+
cond_feature_dim: int = 35,
|
| 249 |
+
activation: Callable[[Tensor], Tensor] = F.gelu,
|
| 250 |
+
use_rotary=True,
|
| 251 |
+
**kwargs
|
| 252 |
+
) -> None:
|
| 253 |
+
super().__init__()
|
| 254 |
+
|
| 255 |
+
self.network = nn.ModuleDict()
|
| 256 |
+
self.network['body_net'] = DanceDecoder(
|
| 257 |
+
nfeats=4+3+22*6,
|
| 258 |
+
seq_len=seq_len,
|
| 259 |
+
latent_dim=latent_dim,
|
| 260 |
+
ff_size=ff_size,
|
| 261 |
+
num_layers=num_layers,
|
| 262 |
+
num_heads=num_heads,
|
| 263 |
+
dropout=dropout,
|
| 264 |
+
cond_feature_dim=cond_feature_dim,
|
| 265 |
+
activation=activation
|
| 266 |
+
)
|
| 267 |
+
self.network['hand_net'] = DanceDecoder(
|
| 268 |
+
nfeats=30*6,
|
| 269 |
+
seq_len=seq_len,
|
| 270 |
+
latent_dim=latent_dim,
|
| 271 |
+
ff_size=ff_size,
|
| 272 |
+
num_layers=num_layers,
|
| 273 |
+
num_heads=num_heads,
|
| 274 |
+
dropout=dropout,
|
| 275 |
+
cond_feature_dim=35+139, # debug !
|
| 276 |
+
activation=activation
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def forward(self, x: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0):
|
| 281 |
+
x_body_start = x[:,:,:4+135]
|
| 282 |
+
x_hand_start = x[:,:,4+135:]
|
| 283 |
+
body_output = self.network['body_net'](x_body_start, cond_embed, times, cond_drop_prob)
|
| 284 |
+
|
| 285 |
+
cond_embed = torch.cat([body_output, cond_embed], dim = -1)
|
| 286 |
+
hand_output = self.network['hand_net'](x_hand_start, cond_embed, times, cond_drop_prob)
|
| 287 |
+
|
| 288 |
+
output = torch.cat([body_output, hand_output], dim=-1)
|
| 289 |
+
return output
|
| 290 |
+
|
| 291 |
+
def guided_forward(self, x, cond_embed, times, guidance_weight):
|
| 292 |
+
unc = self.forward(x, cond_embed, times, cond_drop_prob=1)
|
| 293 |
+
|
| 294 |
+
conditioned = self.forward(x, cond_embed, times, cond_drop_prob=0)
|
| 295 |
+
return unc + (conditioned - unc) * guidance_weight
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class DanceDecoder(nn.Module):
|
| 299 |
+
def __init__(
|
| 300 |
+
self,
|
| 301 |
+
nfeats: int,
|
| 302 |
+
seq_len: int = 150, # 5 seconds, 30 fps
|
| 303 |
+
latent_dim: int = 256,
|
| 304 |
+
ff_size: int = 1024,
|
| 305 |
+
num_layers: int = 4,
|
| 306 |
+
num_heads: int = 4,
|
| 307 |
+
dropout: float = 0.1,
|
| 308 |
+
cond_feature_dim: int = 35,
|
| 309 |
+
activation: Callable[[Tensor], Tensor] = F.gelu,
|
| 310 |
+
use_rotary=True,
|
| 311 |
+
**kwargs
|
| 312 |
+
) -> None:
|
| 313 |
+
|
| 314 |
+
super().__init__()
|
| 315 |
+
|
| 316 |
+
output_feats = nfeats
|
| 317 |
+
|
| 318 |
+
# positional embeddings
|
| 319 |
+
self.rotary = None
|
| 320 |
+
self.abs_pos_encoding = nn.Identity()
|
| 321 |
+
# if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
|
| 322 |
+
if use_rotary:
|
| 323 |
+
self.rotary = RotaryEmbedding(dim=latent_dim)
|
| 324 |
+
else:
|
| 325 |
+
self.abs_pos_encoding = PositionalEncoding(
|
| 326 |
+
latent_dim, dropout, batch_first=True
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
# time embedding processing
|
| 330 |
+
self.time_mlp = nn.Sequential(
|
| 331 |
+
SinusoidalPosEmb(latent_dim), # learned?
|
| 332 |
+
nn.Linear(latent_dim, latent_dim * 4),
|
| 333 |
+
nn.Mish(),
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim),)
|
| 337 |
+
|
| 338 |
+
self.to_time_tokens = nn.Sequential(
|
| 339 |
+
nn.Linear(latent_dim * 4, latent_dim * 2), # 2 time tokens
|
| 340 |
+
Rearrange("b (r d) -> b r d", r=2),
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# null embeddings for guidance dropout
|
| 344 |
+
self.null_cond_embed = nn.Parameter(torch.randn(1, seq_len, latent_dim))
|
| 345 |
+
self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
|
| 346 |
+
|
| 347 |
+
self.norm_cond = nn.LayerNorm(latent_dim)
|
| 348 |
+
|
| 349 |
+
# input projection
|
| 350 |
+
self.input_projection = nn.Linear(nfeats, latent_dim)
|
| 351 |
+
self.cond_encoder = nn.Sequential()
|
| 352 |
+
for _ in range(2):
|
| 353 |
+
self.cond_encoder.append(
|
| 354 |
+
TransformerEncoderLayer(
|
| 355 |
+
d_model=latent_dim,
|
| 356 |
+
nhead=num_heads,
|
| 357 |
+
dim_feedforward=ff_size,
|
| 358 |
+
dropout=dropout,
|
| 359 |
+
activation=activation,
|
| 360 |
+
batch_first=True,
|
| 361 |
+
rotary=self.rotary,
|
| 362 |
+
)
|
| 363 |
+
)
|
| 364 |
+
# conditional projection
|
| 365 |
+
self.cond_projection = nn.Linear(cond_feature_dim, latent_dim) # debug cond_feature_dim
|
| 366 |
+
self.non_attn_cond_projection = nn.Sequential(
|
| 367 |
+
nn.LayerNorm(latent_dim),
|
| 368 |
+
nn.Linear(latent_dim, latent_dim),
|
| 369 |
+
nn.SiLU(),
|
| 370 |
+
nn.Linear(latent_dim, latent_dim),
|
| 371 |
+
)
|
| 372 |
+
# decoder
|
| 373 |
+
decoderstack = nn.ModuleList([])
|
| 374 |
+
for _ in range(num_layers):
|
| 375 |
+
decoderstack.append(
|
| 376 |
+
FiLMTransformerDecoderLayer(
|
| 377 |
+
latent_dim,
|
| 378 |
+
num_heads,
|
| 379 |
+
dim_feedforward=ff_size,
|
| 380 |
+
dropout=dropout,
|
| 381 |
+
activation=activation,
|
| 382 |
+
batch_first=True,
|
| 383 |
+
rotary=self.rotary,
|
| 384 |
+
)
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
self.seqTransDecoder = DecoderLayerStack(decoderstack)
|
| 388 |
+
|
| 389 |
+
self.final_layer = nn.Linear(latent_dim, output_feats)
|
| 390 |
+
|
| 391 |
+
def guided_forward(self, x, cond_embed, times, guidance_weight):
|
| 392 |
+
unc = self.forward(x, cond_embed, times, cond_drop_prob=1)
|
| 393 |
+
conditioned = self.forward(x, cond_embed, times, cond_drop_prob=0)
|
| 394 |
+
|
| 395 |
+
return unc + (conditioned - unc) * guidance_weight
|
| 396 |
+
|
| 397 |
+
def forward(
|
| 398 |
+
self, x: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
|
| 399 |
+
):
|
| 400 |
+
batch_size, device = x.shape[0], x.device
|
| 401 |
+
|
| 402 |
+
# project to latent space
|
| 403 |
+
x = self.input_projection(x)
|
| 404 |
+
# add the positional embeddings of the input sequence to provide temporal information
|
| 405 |
+
x = self.abs_pos_encoding(x)
|
| 406 |
+
|
| 407 |
+
# create music conditional embedding with conditional dropout
|
| 408 |
+
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
|
| 409 |
+
keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
|
| 410 |
+
keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
|
| 411 |
+
|
| 412 |
+
cond_tokens = self.cond_projection(cond_embed)
|
| 413 |
+
# encode tokens
|
| 414 |
+
cond_tokens = self.abs_pos_encoding(cond_tokens)
|
| 415 |
+
cond_tokens = self.cond_encoder(cond_tokens)
|
| 416 |
+
|
| 417 |
+
null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
|
| 418 |
+
cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)
|
| 419 |
+
|
| 420 |
+
mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
|
| 421 |
+
cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
|
| 422 |
+
|
| 423 |
+
# create the diffusion timestep embedding, add the extra music projection
|
| 424 |
+
t_hidden = self.time_mlp(times)
|
| 425 |
+
|
| 426 |
+
# project to attention and FiLM conditioning
|
| 427 |
+
t = self.to_time_cond(t_hidden)
|
| 428 |
+
t_tokens = self.to_time_tokens(t_hidden)
|
| 429 |
+
|
| 430 |
+
# FiLM conditioning
|
| 431 |
+
null_cond_hidden = self.null_cond_hidden.to(t.dtype)
|
| 432 |
+
cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
|
| 433 |
+
t += cond_hidden
|
| 434 |
+
|
| 435 |
+
# cross-attention conditioning
|
| 436 |
+
c = torch.cat((cond_tokens, t_tokens), dim=-2)
|
| 437 |
+
cond_tokens = self.norm_cond(c)
|
| 438 |
+
|
| 439 |
+
# Pass through the transformer decoder
|
| 440 |
+
# attending to the conditional embedding
|
| 441 |
+
output = self.seqTransDecoder(x, cond_tokens, t)
|
| 442 |
+
|
| 443 |
+
output = self.final_layer(output)
|
| 444 |
+
return output
|
model/rotary_embedding_torch.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from inspect import isfunction
|
| 2 |
+
from math import log, pi
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from einops import rearrange, repeat
|
| 6 |
+
from torch import einsum, nn
|
| 7 |
+
|
| 8 |
+
# helper functions
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def exists(val):
|
| 12 |
+
return val is not None
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def broadcat(tensors, dim=-1):
|
| 16 |
+
num_tensors = len(tensors)
|
| 17 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
| 18 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
| 19 |
+
shape_len = list(shape_lens)[0]
|
| 20 |
+
|
| 21 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 22 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
| 23 |
+
|
| 24 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 25 |
+
assert all(
|
| 26 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
| 27 |
+
), "invalid dimensions for broadcastable concatentation"
|
| 28 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
| 29 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
| 30 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 31 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
| 32 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
| 33 |
+
return torch.cat(tensors, dim=dim)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# rotary embedding helper functions
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def rotate_half(x):
|
| 40 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 41 |
+
x1, x2 = x.unbind(dim=-1)
|
| 42 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 43 |
+
return rearrange(x, "... d r -> ... (d r)")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def apply_rotary_emb(freqs, t, start_index=0):
|
| 47 |
+
freqs = freqs.to(t)
|
| 48 |
+
rot_dim = freqs.shape[-1]
|
| 49 |
+
end_index = start_index + rot_dim
|
| 50 |
+
assert (
|
| 51 |
+
rot_dim <= t.shape[-1]
|
| 52 |
+
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
| 53 |
+
t_left, t, t_right = (
|
| 54 |
+
t[..., :start_index],
|
| 55 |
+
t[..., start_index:end_index],
|
| 56 |
+
t[..., end_index:],
|
| 57 |
+
)
|
| 58 |
+
t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
|
| 59 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# learned rotation helpers
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
|
| 66 |
+
if exists(freq_ranges):
|
| 67 |
+
rotations = einsum("..., f -> ... f", rotations, freq_ranges)
|
| 68 |
+
rotations = rearrange(rotations, "... r f -> ... (r f)")
|
| 69 |
+
|
| 70 |
+
rotations = repeat(rotations, "... n -> ... (n r)", r=2)
|
| 71 |
+
return apply_rotary_emb(rotations, t, start_index=start_index)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# classes
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class RotaryEmbedding(nn.Module):
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
dim,
|
| 81 |
+
custom_freqs=None,
|
| 82 |
+
freqs_for="lang",
|
| 83 |
+
theta=10000,
|
| 84 |
+
max_freq=10,
|
| 85 |
+
num_freqs=1,
|
| 86 |
+
learned_freq=False,
|
| 87 |
+
):
|
| 88 |
+
super().__init__()
|
| 89 |
+
if exists(custom_freqs):
|
| 90 |
+
freqs = custom_freqs
|
| 91 |
+
elif freqs_for == "lang":
|
| 92 |
+
freqs = 1.0 / (
|
| 93 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
| 94 |
+
)
|
| 95 |
+
elif freqs_for == "pixel":
|
| 96 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 97 |
+
elif freqs_for == "constant":
|
| 98 |
+
freqs = torch.ones(num_freqs).float()
|
| 99 |
+
else:
|
| 100 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
| 101 |
+
|
| 102 |
+
self.cache = dict()
|
| 103 |
+
|
| 104 |
+
if learned_freq:
|
| 105 |
+
self.freqs = nn.Parameter(freqs)
|
| 106 |
+
else:
|
| 107 |
+
self.register_buffer("freqs", freqs)
|
| 108 |
+
|
| 109 |
+
def rotate_queries_or_keys(self, t, seq_dim=-2):
|
| 110 |
+
device = t.device
|
| 111 |
+
seq_len = t.shape[seq_dim]
|
| 112 |
+
freqs = self.forward(
|
| 113 |
+
lambda: torch.arange(seq_len, device=device), cache_key=seq_len
|
| 114 |
+
)
|
| 115 |
+
return apply_rotary_emb(freqs, t)
|
| 116 |
+
|
| 117 |
+
def forward(self, t, cache_key=None):
|
| 118 |
+
if exists(cache_key) and cache_key in self.cache:
|
| 119 |
+
return self.cache[cache_key]
|
| 120 |
+
|
| 121 |
+
if isfunction(t):
|
| 122 |
+
t = t()
|
| 123 |
+
|
| 124 |
+
freqs = self.freqs
|
| 125 |
+
|
| 126 |
+
freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
| 127 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
| 128 |
+
|
| 129 |
+
if exists(cache_key):
|
| 130 |
+
self.cache[cache_key] = freqs
|
| 131 |
+
|
| 132 |
+
return freqs
|
model/utils.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from einops import rearrange, reduce, repeat
|
| 6 |
+
from einops.layers.torch import Rearrange
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# absolute positional embedding used for vanilla transformer sequential data
|
| 11 |
+
class PositionalEncoding(nn.Module):
|
| 12 |
+
def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.batch_first = batch_first
|
| 15 |
+
|
| 16 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 17 |
+
|
| 18 |
+
pe = torch.zeros(max_len, d_model)
|
| 19 |
+
position = torch.arange(0, max_len).unsqueeze(1)
|
| 20 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
|
| 21 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 22 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 23 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
| 24 |
+
|
| 25 |
+
self.register_buffer("pe", pe)
|
| 26 |
+
|
| 27 |
+
def forward(self, x):
|
| 28 |
+
if self.batch_first:
|
| 29 |
+
x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
|
| 30 |
+
else:
|
| 31 |
+
x = x + self.pe[: x.shape[0], :]
|
| 32 |
+
return self.dropout(x)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# very similar positional embedding used for diffusion timesteps
|
| 36 |
+
class SinusoidalPosEmb(nn.Module):
|
| 37 |
+
def __init__(self, dim):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.dim = dim
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
device = x.device
|
| 43 |
+
half_dim = self.dim // 2
|
| 44 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 45 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
| 46 |
+
emb = x[:, None] * emb[None, :]
|
| 47 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
| 48 |
+
return emb
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# dropout mask
|
| 52 |
+
def prob_mask_like(shape, prob, device):
|
| 53 |
+
if prob == 1:
|
| 54 |
+
return torch.ones(shape, device=device, dtype=torch.bool)
|
| 55 |
+
elif prob == 0:
|
| 56 |
+
return torch.zeros(shape, device=device, dtype=torch.bool)
|
| 57 |
+
else:
|
| 58 |
+
return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def extract(a, t, x_shape):
|
| 62 |
+
b, *_ = t.shape
|
| 63 |
+
out = a.gather(-1, t)
|
| 64 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def make_beta_schedule(
|
| 68 |
+
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
| 69 |
+
):
|
| 70 |
+
if schedule == "linear":
|
| 71 |
+
betas = (
|
| 72 |
+
torch.linspace(
|
| 73 |
+
linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
|
| 74 |
+
)
|
| 75 |
+
** 2
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
elif schedule == "cosine":
|
| 79 |
+
timesteps = (
|
| 80 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
| 81 |
+
)
|
| 82 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
| 83 |
+
alphas = torch.cos(alphas).pow(2)
|
| 84 |
+
alphas = alphas / alphas[0]
|
| 85 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
| 86 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
| 87 |
+
|
| 88 |
+
elif schedule == "sqrt_linear":
|
| 89 |
+
betas = torch.linspace(
|
| 90 |
+
linear_start, linear_end, n_timestep, dtype=torch.float64
|
| 91 |
+
)
|
| 92 |
+
elif schedule == "sqrt":
|
| 93 |
+
betas = (
|
| 94 |
+
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
| 95 |
+
** 0.5
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
| 99 |
+
return betas.numpy()
|
render.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
import os
|
| 6 |
+
# os.environ["PYOPENGL_PLATFORM"] = "osmesa" # Not available on macOS
|
| 7 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from smplx import SMPL, SMPLX, SMPLH
|
| 10 |
+
import pyrender
|
| 11 |
+
import trimesh
|
| 12 |
+
import subprocess
|
| 13 |
+
import pickle
|
| 14 |
+
from pytorch3d.transforms import (axis_angle_to_matrix, matrix_to_axis_angle,
|
| 15 |
+
matrix_to_quaternion, matrix_to_rotation_6d,
|
| 16 |
+
quaternion_to_matrix, rotation_6d_to_matrix)
|
| 17 |
+
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append('.')
|
| 20 |
+
import argparse
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def quat_to_6v(q):
|
| 24 |
+
assert q.shape[-1] == 4
|
| 25 |
+
mat = quaternion_to_matrix(q)
|
| 26 |
+
mat = matrix_to_rotation_6d(mat)
|
| 27 |
+
return mat
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def quat_from_6v(q):
|
| 31 |
+
assert q.shape[-1] == 6
|
| 32 |
+
mat = rotation_6d_to_matrix(q)
|
| 33 |
+
quat = matrix_to_quaternion(mat)
|
| 34 |
+
return quat
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def ax_to_6v(q):
|
| 38 |
+
assert q.shape[-1] == 3
|
| 39 |
+
mat = axis_angle_to_matrix(q)
|
| 40 |
+
mat = matrix_to_rotation_6d(mat)
|
| 41 |
+
return mat
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def ax_from_6v(q):
|
| 45 |
+
assert q.shape[-1] == 6
|
| 46 |
+
mat = rotation_6d_to_matrix(q)
|
| 47 |
+
ax = matrix_to_axis_angle(mat)
|
| 48 |
+
return ax
|
| 49 |
+
|
| 50 |
+
class MovieMaker():
|
| 51 |
+
def __init__(self, save_path) -> None:
|
| 52 |
+
self.mag = 2
|
| 53 |
+
self.eyes = np.array([[3,-3,2], [0,0,-2], [0,0,4], [-8,-8,1], [0,-2,4], [0,2,4]])
|
| 54 |
+
self.centers = np.array([[0,0,0],[0,0,0],[0,0.5,0],[0,0,-1], [0,0.5,0], [0,0.5,0]])
|
| 55 |
+
self.ups = np.array([[0,0,1],[0,1,0],[0,1,0],[0,0,-1], [0,1,0], [0,1,0]])
|
| 56 |
+
self.save_path = save_path
|
| 57 |
+
self.fps = args.fps
|
| 58 |
+
self.img_size = (1200,1200)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# SMPLH_path = "assets/smpl_model/smplh/SMPLH_MALE.pkl"
|
| 62 |
+
# SMPL_path = "assets/smpl_model/smpl/SMPL_MALE.pkl"
|
| 63 |
+
SMPLX_path = "assets/smpl_model/smplx/SMPLX_NEUTRAL.npz"
|
| 64 |
+
trimesh_path = 'assets/NORMAL_new.obj'
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# self.smplh = SMPLH(SMPLH_path, use_pca=False, flat_hand_mean=True)
|
| 68 |
+
# self.smplh.to(f'cuda:{args.gpu}').eval()
|
| 69 |
+
|
| 70 |
+
# self.smpl = SMPL(SMPL_path)
|
| 71 |
+
# self.smpl.to(f'cuda:{args.gpu}').eval()
|
| 72 |
+
|
| 73 |
+
self.smplx = SMPLX(SMPLX_path, use_pca=False, flat_hand_mean=True).eval()
|
| 74 |
+
_device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 75 |
+
self.smplx.to(_device).eval()
|
| 76 |
+
|
| 77 |
+
self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 1.0])
|
| 78 |
+
camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
|
| 79 |
+
camera_pose = look_at(self.eyes[5], self.centers[5], self.ups[5]) # 2
|
| 80 |
+
self.scene.add(camera, pose=camera_pose)
|
| 81 |
+
light = pyrender.DirectionalLight(color=np.ones(3), intensity=3.0)
|
| 82 |
+
self.scene.add(light, pose=camera_pose)
|
| 83 |
+
self.r = pyrender.OffscreenRenderer(self.img_size[0], self.img_size[1])
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# self.mesh = trimesh.load(trimesh_path)
|
| 87 |
+
# floor_mesh = pyrender.Mesh.from_trimesh(self.mesh)
|
| 88 |
+
# floor_node = self.scene.add(floor_mesh)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def save_video(self, save_path, color_list):
|
| 92 |
+
# save_path = os.path.join(save_path,'move.mp4')
|
| 93 |
+
f = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
| 94 |
+
videowriter = cv2.VideoWriter(save_path,f,self.fps,self.img_size)
|
| 95 |
+
for i in range(len(color_list)):
|
| 96 |
+
videowriter.write(color_list[i][:,:,::-1])
|
| 97 |
+
videowriter.release()
|
| 98 |
+
|
| 99 |
+
def get_imgs(self, motion):
|
| 100 |
+
meshes = self.motion2mesh(motion)
|
| 101 |
+
imgs = self.render_imgs(meshes)
|
| 102 |
+
return np.concatenate(imgs, axis=1)
|
| 103 |
+
|
| 104 |
+
def motion2mesh(self, motion):
|
| 105 |
+
if args.mode == "smpl":
|
| 106 |
+
output = self.smpl.forward(
|
| 107 |
+
betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
|
| 108 |
+
transl = motion[:,:3],
|
| 109 |
+
global_orient = motion[:,3:6],
|
| 110 |
+
body_pose = torch.cat([motion[:,6:69], motion[:,69:72], motion[:,114:117]], dim=1)
|
| 111 |
+
)
|
| 112 |
+
elif args.mode == "smplh":
|
| 113 |
+
output = self.smplh.forward(
|
| 114 |
+
betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
|
| 115 |
+
# transl = motion[:,:3],
|
| 116 |
+
transl = torch.tensor([[0,0,-1]]).expand(motion.shape[0],-1).to(motion.device) ,
|
| 117 |
+
global_orient = motion[:,3:6],
|
| 118 |
+
body_pose = motion[:,6:69],
|
| 119 |
+
left_hand_pose = motion[:,69:114],
|
| 120 |
+
right_hand_pose = motion[:,114:159],
|
| 121 |
+
)
|
| 122 |
+
elif args.mode == "smplx":
|
| 123 |
+
output = self.smplx.forward(
|
| 124 |
+
betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
|
| 125 |
+
# transl = motion[:,:3],
|
| 126 |
+
transl = motion[:,:3],
|
| 127 |
+
global_orient = motion[:,3:6],
|
| 128 |
+
body_pose = motion[:,6:69],
|
| 129 |
+
jaw_pose = torch.zeros([motion.shape[0], 3]).to(motion),
|
| 130 |
+
leye_pose = torch.zeros([motion.shape[0], 3]).to(motion),
|
| 131 |
+
reye_pose = torch.zeros([motion.shape[0], 3]).to(motion),
|
| 132 |
+
left_hand_pose = motion[:,69:69+45],
|
| 133 |
+
right_hand_pose = motion[:,69+45:],
|
| 134 |
+
expression= torch.zeros([motion.shape[0], 10]).to(motion),
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
meshes = []
|
| 138 |
+
for i in range(output.vertices.shape[0]):
|
| 139 |
+
if args.mode == 'smplh':
|
| 140 |
+
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplh.faces)
|
| 141 |
+
elif args.mode == 'smplx':
|
| 142 |
+
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplx.faces)
|
| 143 |
+
elif args.mode == 'smpl':
|
| 144 |
+
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smpl.faces)
|
| 145 |
+
# mesh.export(os.path.join(self.save_path, f'{i}.obj'))
|
| 146 |
+
meshes.append(mesh)
|
| 147 |
+
|
| 148 |
+
return meshes
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def render_multi_view(self, meshes, music_file, tab='', eyes=None, centers=None, ups=None, views=1):
|
| 152 |
+
if eyes and centers and ups:
|
| 153 |
+
assert eyes.shape == centers.shape == ups.shape
|
| 154 |
+
else:
|
| 155 |
+
eyes = self.eyes
|
| 156 |
+
centers = self.centers
|
| 157 |
+
ups = self.ups
|
| 158 |
+
|
| 159 |
+
for i in range(views):
|
| 160 |
+
color_list = self.render_single_view(meshes, eyes[1], centers[1], ups[1])
|
| 161 |
+
movie_file = os.path.join(self.save_path, tab + '-' + str(i) + '.mp4')
|
| 162 |
+
output_file = os.path.join(self.save_path, tab + '-' + str(i) + '-music.mp4')
|
| 163 |
+
self.save_video(movie_file, color_list)
|
| 164 |
+
if music_file is not None:
|
| 165 |
+
subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
|
| 166 |
+
else:
|
| 167 |
+
subprocess.run(['ffmpeg','-i',movie_file,output_file])
|
| 168 |
+
# if music_file is not None:
|
| 169 |
+
# subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
|
| 170 |
+
# else:
|
| 171 |
+
# subprocess.run(['ffmpeg','-i',movie_file,output_file])
|
| 172 |
+
os.remove(movie_file)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def render_single_view(self, meshes):
|
| 178 |
+
num = len(meshes)
|
| 179 |
+
color_list = []
|
| 180 |
+
for i in tqdm(range(num)):
|
| 181 |
+
mesh_nodes = []
|
| 182 |
+
for mesh in meshes[i]:
|
| 183 |
+
render_mesh = pyrender.Mesh.from_trimesh(mesh)
|
| 184 |
+
mesh_node = self.scene.add(render_mesh)
|
| 185 |
+
mesh_nodes.append(mesh_node)
|
| 186 |
+
color, _ = self.r.render(self.scene, flags=pyrender.RenderFlags.SHADOWS_DIRECTIONAL)
|
| 187 |
+
color = color.copy()
|
| 188 |
+
color_list.append(color)
|
| 189 |
+
for mesh_node in mesh_nodes:
|
| 190 |
+
self.scene.remove_node(mesh_node)
|
| 191 |
+
return color_list
|
| 192 |
+
|
| 193 |
+
def render_imgs(self, meshes):
|
| 194 |
+
colors = []
|
| 195 |
+
for mesh in meshes:
|
| 196 |
+
render_mesh = pyrender.Mesh.from_trimesh(mesh)
|
| 197 |
+
mesh_node = self.scene.add(render_mesh)
|
| 198 |
+
color, _ = self.r.render(self.scene, flags=pyrender.RenderFlags.SHADOWS_DIRECTIONAL)
|
| 199 |
+
colors.append(color)
|
| 200 |
+
self.scene.remove_node(mesh_node)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
return colors
|
| 204 |
+
# cv2.imwrite(os.path.join(self.save_path, 'test.jpg'), color[:,:,::-1])
|
| 205 |
+
|
| 206 |
+
def run(self, seq_rot, music_file=None, tab='', save_pt=False):
|
| 207 |
+
if isinstance(seq_rot, np.ndarray):
|
| 208 |
+
seq_rot = torch.tensor(seq_rot, dtype=torch.float32, device="mps" if torch.backends.mps.is_available() else "cpu")
|
| 209 |
+
|
| 210 |
+
if save_pt:
|
| 211 |
+
torch.save(seq_rot.detach().cpu(), os.path.join(self.save_path, tab +'_pose.pt'))
|
| 212 |
+
|
| 213 |
+
B, D = seq_rot.shape
|
| 214 |
+
if args.mode == "smpl":
|
| 215 |
+
print("using smpl!!!")
|
| 216 |
+
output = self.smpl.forward(
|
| 217 |
+
betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
|
| 218 |
+
transl = seq_rot[:,:3],
|
| 219 |
+
global_orient = seq_rot[:,3:6],
|
| 220 |
+
body_pose = torch.cat([seq_rot[:,6:69], seq_rot[:,69:72], seq_rot[:,114:117]], dim=1)
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
elif args.mode == "smplh":
|
| 224 |
+
print("using smplh!!!")
|
| 225 |
+
output = self.smplh.forward(
|
| 226 |
+
betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
|
| 227 |
+
transl = seq_rot[:,:3],
|
| 228 |
+
global_orient = seq_rot[:,3:6],
|
| 229 |
+
body_pose = seq_rot[:,6:69],
|
| 230 |
+
left_hand_pose = seq_rot[:,69:114], # torch.zeros([seq_rot.shape[0], 45]).to(seq_rot.device), # seq_rot[:,69:114],
|
| 231 |
+
right_hand_pose = seq_rot[:,114:], # torch.zeros([seq_rot.shape[0], 45]).to(seq_rot.device), #
|
| 232 |
+
expression = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
elif args.mode == "smplx":
|
| 236 |
+
output = self.smplx.forward(
|
| 237 |
+
betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
|
| 238 |
+
# transl = motion[:,:3],
|
| 239 |
+
transl = seq_rot[:,:3],
|
| 240 |
+
global_orient = seq_rot[:,3:6],
|
| 241 |
+
body_pose = seq_rot[:,6:69],
|
| 242 |
+
jaw_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
|
| 243 |
+
leye_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
|
| 244 |
+
reye_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
|
| 245 |
+
left_hand_pose = seq_rot[:,69:69+45],
|
| 246 |
+
right_hand_pose = seq_rot[:,69+45:],
|
| 247 |
+
expression= torch.zeros([seq_rot.shape[0], 10]).to(seq_rot),
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
N, V, DD = output.vertices.shape # 150, 6890, 3
|
| 251 |
+
vertices = output.vertices.reshape((B, -1, V, DD)) # # 150, 1, 6890, 3
|
| 252 |
+
|
| 253 |
+
meshes = []
|
| 254 |
+
for i in range(B):
|
| 255 |
+
# if int(i) > 20:
|
| 256 |
+
# break
|
| 257 |
+
view = []
|
| 258 |
+
for v in vertices[i]:
|
| 259 |
+
# vertices[:,2] *= -1
|
| 260 |
+
if args.mode == 'smplh':
|
| 261 |
+
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplh.faces)
|
| 262 |
+
elif args.mode == 'smplx':
|
| 263 |
+
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplx.faces)
|
| 264 |
+
elif args.mode == 'smpl':
|
| 265 |
+
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smpl.faces)
|
| 266 |
+
# mesh.export(os.path.join(self.save_path, 'test.obj'))
|
| 267 |
+
view.append(mesh)
|
| 268 |
+
meshes.append(view)
|
| 269 |
+
|
| 270 |
+
color_list = self.render_single_view(meshes)
|
| 271 |
+
movie_file = os.path.join(self.save_path, tab + 'tmp.mp4')
|
| 272 |
+
output_file = os.path.join(self.save_path, tab + 'z.mp4')
|
| 273 |
+
self.save_video(movie_file, color_list)
|
| 274 |
+
if music_file is not None:
|
| 275 |
+
subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
|
| 276 |
+
else:
|
| 277 |
+
subprocess.run(['ffmpeg','-i',movie_file,output_file])
|
| 278 |
+
# if music_file is not None:
|
| 279 |
+
# subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
|
| 280 |
+
# else:
|
| 281 |
+
# subprocess.run(['ffmpeg','-i',movie_file,output_file])
|
| 282 |
+
os.remove(movie_file)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def look_at(eye, center, up):
|
| 288 |
+
front = eye - center
|
| 289 |
+
front = front / np.linalg.norm(front)
|
| 290 |
+
right = np.cross(up, front)
|
| 291 |
+
right = right/ np.linalg.norm(right)
|
| 292 |
+
up_new = np.cross(front, right)
|
| 293 |
+
camera_pose = np.eye(4)
|
| 294 |
+
camera_pose[:3,:3] = np.stack([right, up_new, front]).transpose()
|
| 295 |
+
camera_pose[:3,3] = eye
|
| 296 |
+
return camera_pose
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def motion_data_load_process(motionfile):
|
| 300 |
+
if motionfile.split(".")[-1] == "pkl":
|
| 301 |
+
pkl_data = pickle.load(open(motionfile, "rb"))
|
| 302 |
+
smpl_poses = pkl_data["smpl_poses"]
|
| 303 |
+
modata = np.concatenate((pkl_data["smpl_trans"], smpl_poses), axis=1)
|
| 304 |
+
if modata.shape[1] == 69:
|
| 305 |
+
hand_zeros = np.zeros([modata.shape[0], 90], dtype=np.float32)
|
| 306 |
+
modata = np.concatenate((modata, hand_zeros), axis=1)
|
| 307 |
+
assert modata.shape[1] == 159
|
| 308 |
+
modata[:, 1] = modata[:, 1] + 1.3
|
| 309 |
+
return modata
|
| 310 |
+
elif motionfile.split(".")[-1] == "npy":
|
| 311 |
+
modata = np.load(motionfile)
|
| 312 |
+
print("modata.shape", modata.shape)
|
| 313 |
+
if modata.shape[-1] == 315: # first 3-dim is root translation
|
| 314 |
+
print("modata.shape is:", modata.shape)
|
| 315 |
+
rot6d = torch.from_numpy(modata[:,3:])
|
| 316 |
+
T,C = rot6d.shape
|
| 317 |
+
rot6d = rot6d.reshape(-1,6)
|
| 318 |
+
axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
|
| 319 |
+
modata = np.concatenate((modata[:,:3], axis), axis=1)
|
| 320 |
+
print("modata.shape is:", modata.shape)
|
| 321 |
+
elif modata.shape[-1] == 319:
|
| 322 |
+
print("modata.shape is:", modata.shape)
|
| 323 |
+
modata = modata[:,4:]
|
| 324 |
+
rot6d = torch.from_numpy(modata[:,3:])
|
| 325 |
+
T,C = rot6d.shape
|
| 326 |
+
rot6d = rot6d.reshape(-1,6)
|
| 327 |
+
axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
|
| 328 |
+
modata = np.concatenate((modata[:,:3], axis), axis=1)
|
| 329 |
+
print("modata.shape is:", modata.shape)
|
| 330 |
+
elif modata.shape[-1] == 168:
|
| 331 |
+
modata = np.concatenate( [modata[:,:21*3+1], modata[:,25*3:]] , axis=1)
|
| 332 |
+
elif modata.shape[-1] == 159:
|
| 333 |
+
print("modata.shape is:", modata.shape)
|
| 334 |
+
print("modata.shape is:", modata.shape)
|
| 335 |
+
elif modata.shape[-1] == 135:
|
| 336 |
+
print("modata.shape is:", modata.shape)
|
| 337 |
+
if len(modata.shape) == 3 and modata.shape[0] ==1:
|
| 338 |
+
modata = modata.squeeze(0)
|
| 339 |
+
rot6d = torch.from_numpy(modata[:,3:])
|
| 340 |
+
T,C = rot6d.shape
|
| 341 |
+
rot6d = rot6d.reshape(-1,6)
|
| 342 |
+
axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
|
| 343 |
+
hand_zeros = torch.zeros([T, 90]).to(rot6d).detach().cpu().numpy()
|
| 344 |
+
modata = np.concatenate((modata[:,:3], axis, hand_zeros), axis=1)
|
| 345 |
+
print("modata.shape is:", modata.shape)
|
| 346 |
+
else:
|
| 347 |
+
raise("shape error!")
|
| 348 |
+
|
| 349 |
+
modata[:, 1] = modata[:, 1] + 1.3
|
| 350 |
+
return modata
|
| 351 |
+
|
| 352 |
+
if __name__ == '__main__':
|
| 353 |
+
parser = argparse.ArgumentParser()
|
| 354 |
+
parser.add_argument("--gpu", type=str, default="2")
|
| 355 |
+
parser.add_argument("--modir", type=str, default="")
|
| 356 |
+
parser.add_argument("--mode", type=str, default="smplx", choices=['smpl','smplh','smplx'])
|
| 357 |
+
parser.add_argument("--fps", type=int, default=30)
|
| 358 |
+
parser.add_argument("--save_path", type=str, default=None)
|
| 359 |
+
args = parser.parse_args()
|
| 360 |
+
print(args.gpu)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
motion_dir = args.modir
|
| 364 |
+
if args.save_path is not None:
|
| 365 |
+
save_path = args.save_path
|
| 366 |
+
if not os.path.exists(save_path):
|
| 367 |
+
os.makedirs(save_path)
|
| 368 |
+
else:
|
| 369 |
+
save_path = os.path.join(motion_dir, 'video')
|
| 370 |
+
os.makedirs(save_path, exist_ok=True)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
music_dir = "experiments/DanceDiffuse_module/debug--0517_Norm_512len_315_transloss/val1640/samples_2023-05-17-20-54-05"
|
| 374 |
+
for file in os.listdir(motion_dir):
|
| 375 |
+
if file[-3:] in ["npy", "pkl"]:
|
| 376 |
+
|
| 377 |
+
# if there have exist rendered video, continue
|
| 378 |
+
flag = False
|
| 379 |
+
for exists_file in os.listdir(save_path):
|
| 380 |
+
if file[:-4] in exists_file:
|
| 381 |
+
flag = True
|
| 382 |
+
break
|
| 383 |
+
else:
|
| 384 |
+
flag = False
|
| 385 |
+
if flag:
|
| 386 |
+
print("exist", file)
|
| 387 |
+
continue
|
| 388 |
+
|
| 389 |
+
print(file)
|
| 390 |
+
motion_file = os.path.join(motion_dir, file)
|
| 391 |
+
visualizer = MovieMaker(save_path=save_path)
|
| 392 |
+
modata = motion_data_load_process(motion_file)
|
| 393 |
+
visualizer.run(modata, tab=os.path.basename(motion_file).split(".")[0], music_file=None)
|
| 394 |
+
|
| 395 |
+
print('done')
|
smplx_neu_J_1.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aa80c6e1b28e9d43470f250a2d168bacf2cb0cc5e8688ff2e48e71f4db13ba6c
|
| 3 |
+
size 788
|
teaser/teaser.png
ADDED
|
Git LFS Details
|
test.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from functools import cmp_to_key
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import sys
|
| 6 |
+
from tempfile import TemporaryDirectory
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
# import jukemirlib
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import librosa
|
| 14 |
+
import librosa as lr
|
| 15 |
+
import soundfile as sf
|
| 16 |
+
|
| 17 |
+
from args import FineDance_parse_test_opt
|
| 18 |
+
from train_seq import EDGE
|
| 19 |
+
# from data.audio_extraction.jukebox_features import extract as juke_extract
|
| 20 |
+
|
| 21 |
+
def slice_audio(audio_file, stride, length, out_dir):
|
| 22 |
+
# stride, length in seconds
|
| 23 |
+
audio, sr = lr.load(audio_file, sr=None)
|
| 24 |
+
file_name = os.path.splitext(os.path.basename(audio_file))[0]
|
| 25 |
+
start_idx = 0
|
| 26 |
+
idx = 0
|
| 27 |
+
window = int(length * sr)
|
| 28 |
+
stride_step = int(stride * sr)
|
| 29 |
+
while start_idx <= len(audio) - window:
|
| 30 |
+
audio_slice = audio[start_idx : start_idx + window]
|
| 31 |
+
sf.write(f"{out_dir}/{file_name}_slice{idx}.wav", audio_slice, sr)
|
| 32 |
+
start_idx += stride_step
|
| 33 |
+
idx += 1
|
| 34 |
+
return idx
|
| 35 |
+
|
| 36 |
+
def extract(fpath):
|
| 37 |
+
FPS = 30
|
| 38 |
+
HOP_LENGTH = 512
|
| 39 |
+
SR = FPS * HOP_LENGTH
|
| 40 |
+
EPS = 1e-6
|
| 41 |
+
|
| 42 |
+
data, _ = librosa.load(fpath, sr=SR)
|
| 43 |
+
envelope = librosa.onset.onset_strength(y=data, sr=SR) # (seq_len,)
|
| 44 |
+
mfcc = librosa.feature.mfcc(y=data, sr=SR, n_mfcc=20).T # (seq_len, 20)
|
| 45 |
+
chroma = librosa.feature.chroma_cens(
|
| 46 |
+
y=data, sr=SR, hop_length=HOP_LENGTH, n_chroma=12
|
| 47 |
+
).T # (seq_len, 12)
|
| 48 |
+
|
| 49 |
+
peak_idxs = librosa.onset.onset_detect(
|
| 50 |
+
onset_envelope=envelope.flatten(), sr=SR, hop_length=HOP_LENGTH
|
| 51 |
+
)
|
| 52 |
+
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
|
| 53 |
+
peak_onehot[peak_idxs] = 1.0 # (seq_len,)
|
| 54 |
+
|
| 55 |
+
start_bpm = lr.beat.tempo(y=lr.load(fpath)[0])[0]
|
| 56 |
+
|
| 57 |
+
tempo, beat_idxs = librosa.beat.beat_track(
|
| 58 |
+
onset_envelope=envelope,
|
| 59 |
+
sr=SR,
|
| 60 |
+
hop_length=HOP_LENGTH,
|
| 61 |
+
start_bpm=start_bpm,
|
| 62 |
+
tightness=100,
|
| 63 |
+
)
|
| 64 |
+
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
|
| 65 |
+
beat_onehot[beat_idxs] = 1.0 # (seq_len,)
|
| 66 |
+
|
| 67 |
+
audio_feature = np.concatenate(
|
| 68 |
+
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
|
| 69 |
+
axis=-1,
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# chop to ensure exact shape
|
| 73 |
+
audio_feature = audio_feature[:4 * FPS]
|
| 74 |
+
return audio_feature
|
| 75 |
+
|
| 76 |
+
# sort filenames that look like songname_slice{number}.ext
|
| 77 |
+
key_func = lambda x: int(os.path.splitext(x)[0].split("_")[-1].split("slice")[-1])
|
| 78 |
+
# test_list = ["063", "132", "143", "036", "098", "198", "130", "012", "211", "193", "179", "065", "137", "161", "092", "120", "037", "109", "204", "144"]
|
| 79 |
+
test_list = ["063", "144"]
|
| 80 |
+
|
| 81 |
+
def stringintcmp_(a, b):
|
| 82 |
+
aa, bb = "".join(a.split("_")[:-1]), "".join(b.split("_")[:-1])
|
| 83 |
+
ka, kb = key_func(a), key_func(b)
|
| 84 |
+
if aa < bb:
|
| 85 |
+
return -1
|
| 86 |
+
if aa > bb:
|
| 87 |
+
return 1
|
| 88 |
+
if ka < kb:
|
| 89 |
+
return -1
|
| 90 |
+
if ka > kb:
|
| 91 |
+
return 1
|
| 92 |
+
return 0
|
| 93 |
+
|
| 94 |
+
stringintkey = cmp_to_key(stringintcmp_)
|
| 95 |
+
stride_ = 60/30
|
| 96 |
+
|
| 97 |
+
def test(opt):
|
| 98 |
+
feature_func = extract
|
| 99 |
+
sample_length = opt.out_length
|
| 100 |
+
sample_size = int(sample_length / stride_) - 1
|
| 101 |
+
temp_dir_list = []
|
| 102 |
+
all_cond = []
|
| 103 |
+
all_filenames = []
|
| 104 |
+
if opt.use_cached_features: # default is false
|
| 105 |
+
print("Using precomputed features")
|
| 106 |
+
# all subdirectories
|
| 107 |
+
dir_list = glob.glob(os.path.join(opt.feature_cache_dir, "*/"))
|
| 108 |
+
for dir in dir_list:
|
| 109 |
+
file_list = sorted(glob.glob(f"{dir}/*.wav"), key=stringintkey)
|
| 110 |
+
juke_file_list = sorted(glob.glob(f"{dir}/*.npy"), key=stringintkey)
|
| 111 |
+
assert len(file_list) == len(juke_file_list)
|
| 112 |
+
|
| 113 |
+
# random chunk after sanity check
|
| 114 |
+
rand_idx = random.randint(0, len(file_list) - sample_size)
|
| 115 |
+
file_list = file_list[rand_idx : rand_idx + sample_size]
|
| 116 |
+
juke_file_list = juke_file_list[rand_idx : rand_idx + sample_size]
|
| 117 |
+
cond_list = [np.load(x) for x in juke_file_list]
|
| 118 |
+
all_filenames.append(file_list)
|
| 119 |
+
all_cond.append(torch.from_numpy(np.array(cond_list)))
|
| 120 |
+
else:
|
| 121 |
+
print("Computing features for input music")
|
| 122 |
+
for wav_file in glob.glob(os.path.join(opt.music_dir, "*.wav")):
|
| 123 |
+
songname = os.path.splitext(os.path.basename(wav_file))[0]
|
| 124 |
+
# create temp folder (or use the cache folder if specified)
|
| 125 |
+
if True: # songname in test_list:
|
| 126 |
+
if opt.cache_features:
|
| 127 |
+
save_dir = os.path.join(opt.feature_cache_dir, songname)
|
| 128 |
+
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
| 129 |
+
dirname = save_dir
|
| 130 |
+
else:
|
| 131 |
+
temp_dir = TemporaryDirectory()
|
| 132 |
+
print("temp_dir is", temp_dir)
|
| 133 |
+
temp_dir_list.append(temp_dir)
|
| 134 |
+
dirname = temp_dir.name
|
| 135 |
+
# slice the audio file
|
| 136 |
+
print(f"Slicing {wav_file}")
|
| 137 |
+
slice_audio(wav_file, 60/30, 120/30, dirname)
|
| 138 |
+
file_list = sorted(glob.glob(f"{dirname}/*.wav"), key=stringintkey)
|
| 139 |
+
# randomly sample a chunk of length at most sample_size
|
| 140 |
+
rand_idx = random.randint(0, len(file_list) - sample_size)
|
| 141 |
+
cond_list = []
|
| 142 |
+
# generate juke representations
|
| 143 |
+
print(f"Computing features for {wav_file}")
|
| 144 |
+
for idx, file in enumerate(tqdm(file_list)):
|
| 145 |
+
# if not caching then only calculate for the interested range
|
| 146 |
+
if (not opt.cache_features) and (not (rand_idx <= idx < rand_idx + sample_size)):
|
| 147 |
+
continue
|
| 148 |
+
# audio = jukemirlib.load_audio(file)
|
| 149 |
+
# reps = jukemirlib.extract(
|
| 150 |
+
# audio, layers=[66], downsample_target_rate=30
|
| 151 |
+
# )[66]
|
| 152 |
+
reps = feature_func(file)[:opt.full_seq_len]
|
| 153 |
+
# save reps
|
| 154 |
+
if opt.cache_features:
|
| 155 |
+
featurename = os.path.splitext(file)[0] + ".npy"
|
| 156 |
+
np.save(featurename, reps)
|
| 157 |
+
# if in the random range, put it into the list of reps we want
|
| 158 |
+
# to actually use for generation
|
| 159 |
+
if rand_idx <= idx < rand_idx + sample_size:
|
| 160 |
+
cond_list.append(reps)
|
| 161 |
+
cond_list = torch.from_numpy(np.array(cond_list))
|
| 162 |
+
all_cond.append(cond_list)
|
| 163 |
+
all_filenames.append(file_list[rand_idx : rand_idx + sample_size])
|
| 164 |
+
|
| 165 |
+
model = EDGE(opt, opt.feature_type, opt.checkpoint)
|
| 166 |
+
model.eval()
|
| 167 |
+
|
| 168 |
+
# directory for optionally saving the dances for eval
|
| 169 |
+
fk_out = None
|
| 170 |
+
if opt.save_motions:
|
| 171 |
+
fk_out = opt.motion_save_dir
|
| 172 |
+
|
| 173 |
+
print("Generating dances")
|
| 174 |
+
for i in range(len(all_cond)):
|
| 175 |
+
data_tuple = None, all_cond[i], all_filenames[i]
|
| 176 |
+
model.render_sample(
|
| 177 |
+
data_tuple, "test", opt.render_dir, render_count=-1, fk_out=fk_out, mode="long", render=not opt.no_render
|
| 178 |
+
)
|
| 179 |
+
print("Done")
|
| 180 |
+
if torch.cuda.is_available():
|
| 181 |
+
torch.cuda.empty_cache()
|
| 182 |
+
for temp_dir in temp_dir_list:
|
| 183 |
+
temp_dir.cleanup()
|
| 184 |
+
|
| 185 |
+
if __name__ == "__main__":
|
| 186 |
+
opt = FineDance_parse_test_opt()
|
| 187 |
+
test(opt)
|
train_seq.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import multiprocessing
|
| 2 |
+
import os
|
| 3 |
+
from zlib import Z_FULL_FLUSH
|
| 4 |
+
# os.environ["WANDB_API_KEY"] = "your WANDB_API_KEY" #
|
| 5 |
+
# os.environ["WANDB_MODE"] = "online"
|
| 6 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
|
| 7 |
+
import pickle
|
| 8 |
+
from functools import partial
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from args import FineDance_parse_train_opt, save_arguments_to_yaml
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import wandb
|
| 16 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
| 17 |
+
from accelerate.state import AcceleratorState
|
| 18 |
+
from torch.utils.data import DataLoader
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
from dataset.FineDance_dataset import FineDance_Smpl
|
| 22 |
+
from dataset.preprocess import increment_path
|
| 23 |
+
from dataset.preprocess import My_Normalizer as Normalizer # do not use Normalizer
|
| 24 |
+
from model.adan import Adan
|
| 25 |
+
from model.diffusion import GaussianDiffusion
|
| 26 |
+
from model.model import DanceDecoder, SeqModel
|
| 27 |
+
from vis import SMPLX_Skeleton, SMPLSkeleton
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def wrap(x):
|
| 31 |
+
return {f"module.{key}": value for key, value in x.items()}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def maybe_wrap(x, num):
|
| 35 |
+
return x if num == 1 else wrap(x)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class EDGE:
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
opt,
|
| 42 |
+
feature_type,
|
| 43 |
+
checkpoint_path="",
|
| 44 |
+
normalizer=None,
|
| 45 |
+
EMA=True,
|
| 46 |
+
learning_rate=4e-4,
|
| 47 |
+
weight_decay=0.02,
|
| 48 |
+
):
|
| 49 |
+
self.opt = opt
|
| 50 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 51 |
+
self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
| 52 |
+
state = AcceleratorState()
|
| 53 |
+
num_processes = state.num_processes
|
| 54 |
+
|
| 55 |
+
self.repr_dim = repr_dim = opt.nfeats
|
| 56 |
+
feature_dim = 35
|
| 57 |
+
|
| 58 |
+
self.horizon = horizon = opt.full_seq_len
|
| 59 |
+
|
| 60 |
+
self.accelerator.wait_for_everyone()
|
| 61 |
+
|
| 62 |
+
self.resume_num = 0
|
| 63 |
+
checkpoint = None
|
| 64 |
+
self.normalizer = None
|
| 65 |
+
if checkpoint_path != "":
|
| 66 |
+
checkpoint = torch.load(
|
| 67 |
+
checkpoint_path, map_location=self.accelerator.device
|
| 68 |
+
)
|
| 69 |
+
self.resume_num = int(os.path.basename(checkpoint_path).split("-")[1].split(".")[0]) # int(os.path.basenam
|
| 70 |
+
|
| 71 |
+
model = SeqModel(
|
| 72 |
+
nfeats=repr_dim,
|
| 73 |
+
seq_len=horizon,
|
| 74 |
+
latent_dim=512,
|
| 75 |
+
ff_size=1024,
|
| 76 |
+
num_layers=8,
|
| 77 |
+
num_heads=8,
|
| 78 |
+
dropout=0.1,
|
| 79 |
+
cond_feature_dim=feature_dim,
|
| 80 |
+
activation=F.gelu,
|
| 81 |
+
)
|
| 82 |
+
if opt.nfeats == 139 or opt.nfeats == 135:
|
| 83 |
+
smplx_fk = SMPLSkeleton(device=self.accelerator.device)
|
| 84 |
+
else:
|
| 85 |
+
smplx_fk = SMPLX_Skeleton(device=self.accelerator.device, batch=512000)
|
| 86 |
+
diffusion = GaussianDiffusion(
|
| 87 |
+
model,
|
| 88 |
+
opt,
|
| 89 |
+
horizon,
|
| 90 |
+
repr_dim,
|
| 91 |
+
smplx_model = smplx_fk,
|
| 92 |
+
schedule="cosine",
|
| 93 |
+
n_timestep=1000,
|
| 94 |
+
predict_epsilon=False,
|
| 95 |
+
loss_type="l2",
|
| 96 |
+
use_p2=False,
|
| 97 |
+
cond_drop_prob=0.25,
|
| 98 |
+
guidance_weight=2,
|
| 99 |
+
do_normalize = opt.do_normalize
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
print(
|
| 103 |
+
"Model has {} parameters".format(sum(y.numel() for y in model.parameters()))
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self.model = self.accelerator.prepare(model)
|
| 107 |
+
self.diffusion = diffusion.to(self.accelerator.device) # 为什么这里不需要prepare
|
| 108 |
+
self.smplx_fk = smplx_fk # to(self.accelerator.device)
|
| 109 |
+
optim = Adan(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
| 110 |
+
self.optim = self.accelerator.prepare(optim)
|
| 111 |
+
|
| 112 |
+
if checkpoint_path != "":
|
| 113 |
+
self.model.load_state_dict(
|
| 114 |
+
maybe_wrap(
|
| 115 |
+
checkpoint["ema_state_dict" if EMA else "model_state_dict"],
|
| 116 |
+
num_processes,
|
| 117 |
+
)
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def eval(self):
|
| 121 |
+
self.diffusion.eval()
|
| 122 |
+
|
| 123 |
+
def train(self):
|
| 124 |
+
self.diffusion.train()
|
| 125 |
+
|
| 126 |
+
def prepare(self, objects):
|
| 127 |
+
return self.accelerator.prepare(*objects)
|
| 128 |
+
|
| 129 |
+
def train_loop(self, opt):
|
| 130 |
+
print("train_dataset = FineDance_Dataset ")
|
| 131 |
+
train_dataset = FineDance_Smpl(
|
| 132 |
+
args=opt, # data/
|
| 133 |
+
istrain=True,
|
| 134 |
+
)
|
| 135 |
+
test_dataset = FineDance_Smpl(
|
| 136 |
+
args=opt,
|
| 137 |
+
istrain=False,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
num_cpus = multiprocessing.cpu_count()
|
| 141 |
+
print("batchsize=:", opt.batch_size)
|
| 142 |
+
train_data_loader = DataLoader(
|
| 143 |
+
train_dataset,
|
| 144 |
+
batch_size=opt.batch_size,
|
| 145 |
+
shuffle=True,
|
| 146 |
+
num_workers=min(int(num_cpus * 0.5), 40), # num_workers=min(int(num_cpus * 0.75), 32),
|
| 147 |
+
pin_memory=True,
|
| 148 |
+
drop_last=True,
|
| 149 |
+
)
|
| 150 |
+
test_data_loader = DataLoader(
|
| 151 |
+
test_dataset,
|
| 152 |
+
batch_size=opt.batch_size,
|
| 153 |
+
shuffle=True,
|
| 154 |
+
num_workers=2,
|
| 155 |
+
pin_memory=True,
|
| 156 |
+
drop_last=True,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
train_data_loader = self.accelerator.prepare(train_data_loader)
|
| 160 |
+
# boot up multi-gpu training. test dataloader is only on main process
|
| 161 |
+
load_loop = (
|
| 162 |
+
partial(tqdm, position=1, desc="Batch")
|
| 163 |
+
if self.accelerator.is_main_process
|
| 164 |
+
else lambda x: x
|
| 165 |
+
)
|
| 166 |
+
if self.accelerator.is_main_process:
|
| 167 |
+
save_dir = str(increment_path(Path(opt.project) / opt.exp_name))
|
| 168 |
+
opt.exp_name = save_dir.split("/")[-1]
|
| 169 |
+
wandb.init(project=opt.wandb_pj_name, name=opt.exp_name)
|
| 170 |
+
save_dir = Path(save_dir)
|
| 171 |
+
wdir = save_dir / "weights"
|
| 172 |
+
wdir.mkdir(parents=True, exist_ok=True)
|
| 173 |
+
wandb.save("params.yaml") # 保存wandb配置到文件
|
| 174 |
+
yaml_path = os.path.join(wdir, 'parameters.yaml')
|
| 175 |
+
save_arguments_to_yaml(opt, yaml_path)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
self.accelerator.wait_for_everyone()
|
| 179 |
+
for epoch in range(1, opt.epochs + 1):
|
| 180 |
+
print("epoch:", epoch+self.resume_num)
|
| 181 |
+
avg_loss = 0
|
| 182 |
+
avg_vloss = 0
|
| 183 |
+
avg_fkloss = 0
|
| 184 |
+
avg_footloss = 0
|
| 185 |
+
|
| 186 |
+
# train
|
| 187 |
+
self.train()
|
| 188 |
+
for step, (x, cond, filename) in enumerate(
|
| 189 |
+
load_loop(train_data_loader)
|
| 190 |
+
):
|
| 191 |
+
if opt.nfeats == 139 or opt.nfeats==135:
|
| 192 |
+
x = x[:, :, :139]
|
| 193 |
+
|
| 194 |
+
total_loss, (loss, v_loss, fk_loss, foot_loss) = self.diffusion(
|
| 195 |
+
x, cond, t_override=None
|
| 196 |
+
)
|
| 197 |
+
# print("3")
|
| 198 |
+
self.optim.zero_grad()
|
| 199 |
+
self.accelerator.backward(total_loss)
|
| 200 |
+
self.optim.step()
|
| 201 |
+
|
| 202 |
+
# ema update and train loss update only on main
|
| 203 |
+
if self.accelerator.is_main_process:
|
| 204 |
+
avg_loss += loss.detach().cpu().numpy()
|
| 205 |
+
avg_vloss += v_loss.detach().cpu().numpy()
|
| 206 |
+
avg_fkloss += fk_loss.detach().cpu().numpy()
|
| 207 |
+
avg_footloss += foot_loss.detach().cpu().numpy()
|
| 208 |
+
if step % opt.ema_interval == 0:
|
| 209 |
+
self.diffusion.ema.update_model_average(
|
| 210 |
+
self.diffusion.master_model, self.diffusion.model
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
#-----------------------------------------------------------------------------------------------------------
|
| 214 |
+
# test
|
| 215 |
+
# Save model
|
| 216 |
+
|
| 217 |
+
if ((epoch+self.resume_num) % opt.save_interval) == 0 or epoch<=1:
|
| 218 |
+
# everyone waits here for the val loop to finish ( don't start next train epoch early)
|
| 219 |
+
self.accelerator.wait_for_everyone()
|
| 220 |
+
self.eval() # debug!
|
| 221 |
+
# save only if on main thread
|
| 222 |
+
if self.accelerator.is_main_process:
|
| 223 |
+
# self.eval()
|
| 224 |
+
# log
|
| 225 |
+
avg_loss /= len(train_data_loader)
|
| 226 |
+
avg_vloss /= len(train_data_loader)
|
| 227 |
+
avg_fkloss /= len(train_data_loader)
|
| 228 |
+
avg_footloss /= len(train_data_loader)
|
| 229 |
+
log_dict = {
|
| 230 |
+
"Train Loss": avg_loss,
|
| 231 |
+
"V Loss": avg_vloss,
|
| 232 |
+
"FK Loss": avg_fkloss,
|
| 233 |
+
"Foot Loss": avg_footloss,
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
wandb.log(log_dict)
|
| 237 |
+
|
| 238 |
+
ckpt = {
|
| 239 |
+
"ema_state_dict": self.diffusion.master_model.state_dict(), # 经过accelerate prepare的模型,在保存时需要unwrap,反之不需要
|
| 240 |
+
"model_state_dict": self.accelerator.unwrap_model(
|
| 241 |
+
self.model
|
| 242 |
+
).state_dict(),
|
| 243 |
+
"optimizer_state_dict": self.optim.state_dict(),
|
| 244 |
+
"normalizer": self.normalizer,
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
torch.save(ckpt, os.path.join(wdir, f"train-{epoch+self.resume_num}.pt"))
|
| 248 |
+
print(f"[MODEL SAVED at Epoch {epoch+self.resume_num}]")
|
| 249 |
+
|
| 250 |
+
# generate a sample
|
| 251 |
+
render_count = 2
|
| 252 |
+
shape = (render_count, self.horizon, self.opt.nfeats)
|
| 253 |
+
print("Generating Sample")
|
| 254 |
+
# draw a music from the test dataset
|
| 255 |
+
(x, cond, filename) = next(iter(test_data_loader))
|
| 256 |
+
# if opt.do_normalize:
|
| 257 |
+
# x = self.normalizer.normalize(x)
|
| 258 |
+
|
| 259 |
+
if opt.nfeats == 139 or opt.nfeats==135:
|
| 260 |
+
x = x[:, :, :139]
|
| 261 |
+
|
| 262 |
+
cond = cond.to(self.accelerator.device)
|
| 263 |
+
# name_iter = name_iter+1
|
| 264 |
+
self.diffusion.render_sample(
|
| 265 |
+
shape,
|
| 266 |
+
cond[:render_count],
|
| 267 |
+
self.normalizer,
|
| 268 |
+
epoch+self.resume_num,
|
| 269 |
+
render_out = os.path.join(opt.render_dir, "train_" + opt.exp_name), # render out
|
| 270 |
+
fk_out = os.path.join(opt.render_dir, "train_" + opt.exp_name),
|
| 271 |
+
name=filename[:render_count],
|
| 272 |
+
# name = str(epoch) + str(name_iter).zfill(3)
|
| 273 |
+
sound=True,
|
| 274 |
+
)
|
| 275 |
+
#-----------------------------------------------------------------------------------------------------------
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
if self.accelerator.is_main_process:
|
| 279 |
+
wandb.run.finish()
|
| 280 |
+
|
| 281 |
+
def render_sample(
|
| 282 |
+
self, data_tuple, label, render_dir, render_count=-1, mode='normal', fk_out=None, render=True,
|
| 283 |
+
):
|
| 284 |
+
_, cond, wavname = data_tuple
|
| 285 |
+
assert len(cond.shape) == 3
|
| 286 |
+
if render_count < 0:
|
| 287 |
+
render_count = len(cond)
|
| 288 |
+
shape = (render_count, self.horizon, self.repr_dim)
|
| 289 |
+
cond = cond.to(self.accelerator.device).float()
|
| 290 |
+
self.diffusion.render_sample(
|
| 291 |
+
shape,
|
| 292 |
+
cond[:render_count],
|
| 293 |
+
self.normalizer,
|
| 294 |
+
label,
|
| 295 |
+
render_dir,
|
| 296 |
+
name=wavname[:render_count],
|
| 297 |
+
sound=True,
|
| 298 |
+
mode=mode,
|
| 299 |
+
fk_out=fk_out,
|
| 300 |
+
render=render
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
def train(opt):
|
| 304 |
+
model = EDGE(opt, opt.feature_type)
|
| 305 |
+
model.train_loop(opt)
|
| 306 |
+
|
| 307 |
+
if __name__ == "__main__":
|
| 308 |
+
opt = FineDance_parse_train_opt()
|
| 309 |
+
command = ' '.join(sys.argv)
|
| 310 |
+
if not os.path.exists(os.path.join(opt.project, opt.exp_name)):
|
| 311 |
+
os.makedirs(os.path.join(opt.project, opt.exp_name), exist_ok=False)
|
| 312 |
+
with open(os.path.join(opt.project, opt.exp_name, 'command.txt'), 'w') as f:
|
| 313 |
+
f.write(command)
|
| 314 |
+
|
| 315 |
+
yaml_path = os.path.join(opt.project, opt.exp_name, 'parameters.yaml')
|
| 316 |
+
save_arguments_to_yaml(opt, yaml_path)
|
| 317 |
+
|
| 318 |
+
train(opt)
|
vis.py
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import sys
|
| 4 |
+
from tempfile import TemporaryDirectory
|
| 5 |
+
|
| 6 |
+
import librosa as lr
|
| 7 |
+
import matplotlib.animation as animation
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
from mpl_toolkits.mplot3d import axes3d
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import soundfile as sf
|
| 13 |
+
import torch
|
| 14 |
+
from matplotlib import cm
|
| 15 |
+
from matplotlib.colors import ListedColormap
|
| 16 |
+
from pytorch3d.transforms import (axis_angle_to_quaternion, quaternion_apply,
|
| 17 |
+
quaternion_multiply)
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from typing import NewType
|
| 20 |
+
Tensor = NewType('Tensor', torch.Tensor)
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
try:
|
| 23 |
+
import pickle5 as pickle
|
| 24 |
+
except ImportError:
|
| 25 |
+
import pickle
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
smpl_joints = [
|
| 29 |
+
"root", # 0
|
| 30 |
+
"lhip", # 1
|
| 31 |
+
"rhip", # 2
|
| 32 |
+
"belly", # 3
|
| 33 |
+
"lknee", # 4
|
| 34 |
+
"rknee", # 5
|
| 35 |
+
"spine", # 6
|
| 36 |
+
"lankle",# 7
|
| 37 |
+
"rankle",# 8
|
| 38 |
+
"chest", # 9
|
| 39 |
+
"ltoes", # 10
|
| 40 |
+
"rtoes", # 11
|
| 41 |
+
"neck", # 12
|
| 42 |
+
"linshoulder", # 13
|
| 43 |
+
"rinshoulder", # 14
|
| 44 |
+
"head", # 15
|
| 45 |
+
"lshoulder", # 16
|
| 46 |
+
"rshoulder", # 17
|
| 47 |
+
"lelbow", # 18
|
| 48 |
+
"relbow", # 19
|
| 49 |
+
"lwrist", # 20
|
| 50 |
+
"rwrist", # 21
|
| 51 |
+
"lhand", # 22
|
| 52 |
+
"rhand", # 23
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
smplh_joints = [
|
| 56 |
+
'pelvis',
|
| 57 |
+
'left_hip',
|
| 58 |
+
'right_hip',
|
| 59 |
+
'spine1',
|
| 60 |
+
'left_knee',
|
| 61 |
+
'right_knee',
|
| 62 |
+
'spine2',
|
| 63 |
+
'left_ankle',
|
| 64 |
+
'right_ankle',
|
| 65 |
+
'spine3',
|
| 66 |
+
'left_foot',
|
| 67 |
+
'right_foot',
|
| 68 |
+
'neck',
|
| 69 |
+
'left_collar',
|
| 70 |
+
'right_collar',
|
| 71 |
+
'head',
|
| 72 |
+
'left_shoulder',
|
| 73 |
+
'right_shoulder',
|
| 74 |
+
'left_elbow',
|
| 75 |
+
'right_elbow',
|
| 76 |
+
'left_wrist',
|
| 77 |
+
'right_wrist',
|
| 78 |
+
'left_index1',
|
| 79 |
+
'left_index2',
|
| 80 |
+
'left_index3',
|
| 81 |
+
'left_middle1',
|
| 82 |
+
'left_middle2',
|
| 83 |
+
'left_middle3',
|
| 84 |
+
'left_pinky1',
|
| 85 |
+
'left_pinky2',
|
| 86 |
+
'left_pinky3',
|
| 87 |
+
'left_ring1',
|
| 88 |
+
'left_ring2',
|
| 89 |
+
'left_ring3',
|
| 90 |
+
'left_thumb1',
|
| 91 |
+
'left_thumb2',
|
| 92 |
+
'left_thumb3',
|
| 93 |
+
'right_index1',
|
| 94 |
+
'right_index2',
|
| 95 |
+
'right_index3',
|
| 96 |
+
'right_middle1',
|
| 97 |
+
'right_middle2',
|
| 98 |
+
'right_middle3',
|
| 99 |
+
'right_pinky1',
|
| 100 |
+
'right_pinky2',
|
| 101 |
+
'right_pinky3',
|
| 102 |
+
'right_ring1',
|
| 103 |
+
'right_ring2',
|
| 104 |
+
'right_ring3',
|
| 105 |
+
'right_thumb1',
|
| 106 |
+
'right_thumb2',
|
| 107 |
+
'right_thumb3'
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
smplx_joints = [
|
| 112 |
+
'pelvis',
|
| 113 |
+
'left_hip',
|
| 114 |
+
'right_hip',
|
| 115 |
+
'spine1',
|
| 116 |
+
'left_knee',
|
| 117 |
+
'right_knee',
|
| 118 |
+
'spine2',
|
| 119 |
+
'left_ankle',
|
| 120 |
+
'right_ankle',
|
| 121 |
+
'spine3',
|
| 122 |
+
'left_foot',
|
| 123 |
+
'right_foot',
|
| 124 |
+
'neck',
|
| 125 |
+
'left_collar',
|
| 126 |
+
'right_collar',
|
| 127 |
+
'head',
|
| 128 |
+
'left_shoulder',
|
| 129 |
+
'right_shoulder',
|
| 130 |
+
'left_elbow',
|
| 131 |
+
'right_elbow',
|
| 132 |
+
'left_wrist',
|
| 133 |
+
'right_wrist',
|
| 134 |
+
'jaw',
|
| 135 |
+
'left_eye_smplhf',
|
| 136 |
+
'right_eye_smplhf',
|
| 137 |
+
'left_index1',
|
| 138 |
+
'left_index2',
|
| 139 |
+
'left_index3',
|
| 140 |
+
'left_middle1',
|
| 141 |
+
'left_middle2',
|
| 142 |
+
'left_middle3',
|
| 143 |
+
'left_pinky1',
|
| 144 |
+
'left_pinky2',
|
| 145 |
+
'left_pinky3',
|
| 146 |
+
'left_ring1',
|
| 147 |
+
'left_ring2',
|
| 148 |
+
'left_ring3',
|
| 149 |
+
'left_thumb1',
|
| 150 |
+
'left_thumb2',
|
| 151 |
+
'left_thumb3',
|
| 152 |
+
'right_index1',
|
| 153 |
+
'right_index2',
|
| 154 |
+
'right_index3',
|
| 155 |
+
'right_middle1',
|
| 156 |
+
'right_middle2',
|
| 157 |
+
'right_middle3',
|
| 158 |
+
'right_pinky1',
|
| 159 |
+
'right_pinky2',
|
| 160 |
+
'right_pinky3',
|
| 161 |
+
'right_ring1',
|
| 162 |
+
'right_ring2',
|
| 163 |
+
'right_ring3',
|
| 164 |
+
'right_thumb1',
|
| 165 |
+
'right_thumb2',
|
| 166 |
+
'right_thumb3'
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
smpl_parents = [
|
| 171 |
+
-1,
|
| 172 |
+
0,
|
| 173 |
+
0,
|
| 174 |
+
0,
|
| 175 |
+
1,
|
| 176 |
+
2,
|
| 177 |
+
3,
|
| 178 |
+
4,
|
| 179 |
+
5,
|
| 180 |
+
6,
|
| 181 |
+
7,
|
| 182 |
+
8,
|
| 183 |
+
9,
|
| 184 |
+
9,
|
| 185 |
+
9,
|
| 186 |
+
12,
|
| 187 |
+
13,
|
| 188 |
+
14,
|
| 189 |
+
16,
|
| 190 |
+
17,
|
| 191 |
+
18,
|
| 192 |
+
19,
|
| 193 |
+
20,
|
| 194 |
+
21,
|
| 195 |
+
]
|
| 196 |
+
|
| 197 |
+
smplh_parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14,
|
| 198 |
+
16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34,
|
| 199 |
+
35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50]
|
| 200 |
+
|
| 201 |
+
smplx_parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 15, 15, 15, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 20, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50, 21, 52, 53]
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
smpl_offsets = [
|
| 205 |
+
[0.0, 0.0, 0.0],
|
| 206 |
+
[0.05858135, -0.08228004, -0.01766408],
|
| 207 |
+
[-0.06030973, -0.09051332, -0.01354254],
|
| 208 |
+
[0.00443945, 0.12440352, -0.03838522],
|
| 209 |
+
[0.04345142, -0.38646945, 0.008037],
|
| 210 |
+
[-0.04325663, -0.38368791, -0.00484304],
|
| 211 |
+
[0.00448844, 0.1379564, 0.02682033],
|
| 212 |
+
[-0.01479032, -0.42687458, -0.037428],
|
| 213 |
+
[0.01905555, -0.4200455, -0.03456167],
|
| 214 |
+
[-0.00226458, 0.05603239, 0.00285505],
|
| 215 |
+
[0.04105436, -0.06028581, 0.12204243],
|
| 216 |
+
[-0.03483987, -0.06210566, 0.13032329],
|
| 217 |
+
[-0.0133902, 0.21163553, -0.03346758],
|
| 218 |
+
[0.07170245, 0.11399969, -0.01889817],
|
| 219 |
+
[-0.08295366, 0.11247234, -0.02370739],
|
| 220 |
+
[0.01011321, 0.08893734, 0.05040987],
|
| 221 |
+
[0.12292141, 0.04520509, -0.019046],
|
| 222 |
+
[-0.11322832, 0.04685326, -0.00847207],
|
| 223 |
+
[0.2553319, -0.01564902, -0.02294649],
|
| 224 |
+
[-0.26012748, -0.01436928, -0.03126873],
|
| 225 |
+
[0.26570925, 0.01269811, -0.00737473],
|
| 226 |
+
[-0.26910836, 0.00679372, -0.00602676],
|
| 227 |
+
[0.08669055, -0.01063603, -0.01559429],
|
| 228 |
+
[-0.0887537, -0.00865157, -0.01010708],
|
| 229 |
+
]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def set_line_data_3d(line, x):
|
| 233 |
+
line.set_data(x[:, :2].T)
|
| 234 |
+
line.set_3d_properties(x[:, 2])
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def set_scatter_data_3d(scat, x, c):
|
| 238 |
+
scat.set_offsets(x[:, :2])
|
| 239 |
+
scat.set_3d_properties(x[:, 2], "z")
|
| 240 |
+
scat.set_facecolors([c])
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def get_axrange(poses):
|
| 244 |
+
pose = poses[0]
|
| 245 |
+
x_min = pose[:, 0].min()
|
| 246 |
+
x_max = pose[:, 0].max()
|
| 247 |
+
|
| 248 |
+
y_min = pose[:, 1].min()
|
| 249 |
+
y_max = pose[:, 1].max()
|
| 250 |
+
|
| 251 |
+
z_min = pose[:, 2].min()
|
| 252 |
+
z_max = pose[:, 2].max()
|
| 253 |
+
|
| 254 |
+
xdiff = x_max - x_min
|
| 255 |
+
ydiff = y_max - y_min
|
| 256 |
+
zdiff = z_max - z_min
|
| 257 |
+
|
| 258 |
+
biggestdiff = max([xdiff, ydiff, zdiff])
|
| 259 |
+
return biggestdiff
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def plot_single_pose(num, poses, lines, ax, axrange, scat, contact, ske_parents):
|
| 263 |
+
pose = poses[num]
|
| 264 |
+
static = contact[num]
|
| 265 |
+
indices = [7, 8, 10, 11]
|
| 266 |
+
|
| 267 |
+
for i, (point, idx) in enumerate(zip(scat, indices)):
|
| 268 |
+
position = pose[idx : idx + 1]
|
| 269 |
+
color = "r" if static[i] else "g"
|
| 270 |
+
set_scatter_data_3d(point, position, color)
|
| 271 |
+
|
| 272 |
+
for i, (p, line) in enumerate(zip(ske_parents, lines)):
|
| 273 |
+
# don't plot root
|
| 274 |
+
if i == 0:
|
| 275 |
+
continue
|
| 276 |
+
# stack to create a line
|
| 277 |
+
data = np.stack((pose[i], pose[p]), axis=0)
|
| 278 |
+
set_line_data_3d(line, data)
|
| 279 |
+
|
| 280 |
+
if num == 0:
|
| 281 |
+
if isinstance(axrange, int):
|
| 282 |
+
axrange = (axrange, axrange, axrange)
|
| 283 |
+
xcenter, ycenter, zcenter = 0, 0, 2.5
|
| 284 |
+
stepx, stepy, stepz = axrange[0] / 2, axrange[1] / 2, axrange[2] / 2
|
| 285 |
+
|
| 286 |
+
x_min, x_max = xcenter - stepx, xcenter + stepx
|
| 287 |
+
y_min, y_max = ycenter - stepy, ycenter + stepy
|
| 288 |
+
z_min, z_max = zcenter - stepz, zcenter + stepz
|
| 289 |
+
|
| 290 |
+
ax.set_xlim(x_min, x_max)
|
| 291 |
+
ax.set_ylim(y_min, y_max)
|
| 292 |
+
ax.set_zlim(z_min, z_max)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def skeleton_render(
|
| 296 |
+
poses,
|
| 297 |
+
epoch=0,
|
| 298 |
+
out="renders",
|
| 299 |
+
name="",
|
| 300 |
+
sound=True,
|
| 301 |
+
stitch=False,
|
| 302 |
+
sound_folder="ood_sliced",
|
| 303 |
+
contact=None,
|
| 304 |
+
render=True,
|
| 305 |
+
smpl_mode="smpl", # 是否渲染双手
|
| 306 |
+
):
|
| 307 |
+
if render:
|
| 308 |
+
if smpl_mode=="smpl":
|
| 309 |
+
poses = np.concatenate((poses[:, :23, :], np.expand_dims(poses[:, 37, :], axis=1)), axis=1)
|
| 310 |
+
ske_parents = smpl_parents
|
| 311 |
+
elif smpl_mode == "smplx":
|
| 312 |
+
ske_parents = smplx_parents
|
| 313 |
+
|
| 314 |
+
# generate the pose with FK
|
| 315 |
+
Path(out).mkdir(parents=True, exist_ok=True)
|
| 316 |
+
num_steps = poses.shape[0] #
|
| 317 |
+
|
| 318 |
+
fig = plt.figure()
|
| 319 |
+
ax = fig.add_subplot(projection="3d")
|
| 320 |
+
|
| 321 |
+
point = np.array([0, 0, 1])
|
| 322 |
+
normal = np.array([0, 0, 1])
|
| 323 |
+
d = -point.dot(normal)
|
| 324 |
+
xx, yy = np.meshgrid(np.linspace(-1.5, 1.5, 2), np.linspace(-1.5, 1.5, 2))
|
| 325 |
+
z = (-normal[0] * xx - normal[1] * yy - d) * 1.0 / normal[2]
|
| 326 |
+
# plot the plane
|
| 327 |
+
ax.plot_surface(xx, yy, z, zorder=-11, cmap=cm.twilight)
|
| 328 |
+
# Create lines initially without data
|
| 329 |
+
lines = [
|
| 330 |
+
ax.plot([], [], [], zorder=10, linewidth=1.5)[0]
|
| 331 |
+
for _ in ske_parents
|
| 332 |
+
]
|
| 333 |
+
scat = [
|
| 334 |
+
ax.scatter([], [], [], zorder=10, s=0, cmap=ListedColormap(["r", "g", "b"]))
|
| 335 |
+
for _ in range(4)
|
| 336 |
+
]
|
| 337 |
+
axrange = 3
|
| 338 |
+
|
| 339 |
+
# create contact labels
|
| 340 |
+
feet = poses[:, (7, 8, 10, 11)]
|
| 341 |
+
feetv = np.zeros(feet.shape[:2])
|
| 342 |
+
feetv[:-1] = np.linalg.norm(feet[1:] - feet[:-1], axis=-1)
|
| 343 |
+
if contact is None:
|
| 344 |
+
contact = feetv < 0.01
|
| 345 |
+
else:
|
| 346 |
+
contact = contact > 0.95
|
| 347 |
+
|
| 348 |
+
# Creating the Animation object
|
| 349 |
+
anim = animation.FuncAnimation(
|
| 350 |
+
fig,
|
| 351 |
+
plot_single_pose,
|
| 352 |
+
num_steps,
|
| 353 |
+
fargs=(poses, lines, ax, axrange, scat, contact, ske_parents),
|
| 354 |
+
interval=1000 // 30,
|
| 355 |
+
)
|
| 356 |
+
if sound:
|
| 357 |
+
# make a temporary directory to save the intermediate gif in
|
| 358 |
+
if render:
|
| 359 |
+
temp_dir = TemporaryDirectory()
|
| 360 |
+
gifname = os.path.join(temp_dir.name, f"{epoch}.gif")
|
| 361 |
+
anim.save(gifname)
|
| 362 |
+
|
| 363 |
+
# stitch wavs
|
| 364 |
+
if stitch:
|
| 365 |
+
assert type(name) == list # must be a list of names to do stitching
|
| 366 |
+
name_ = [os.path.splitext(x)[0] + ".wav" for x in name]
|
| 367 |
+
audio, sr = lr.load(name_[0], sr=None)
|
| 368 |
+
ll, half = len(audio), len(audio) // 2
|
| 369 |
+
total_wav = np.zeros(ll + half * (len(name_) - 1))
|
| 370 |
+
total_wav[:ll] = audio
|
| 371 |
+
idx = ll
|
| 372 |
+
for n_ in name_[1:]:
|
| 373 |
+
audio, sr = lr.load(n_, sr=None)
|
| 374 |
+
total_wav[idx : idx + half] = audio[half:]
|
| 375 |
+
idx += half
|
| 376 |
+
# save a dummy spliced audio
|
| 377 |
+
audioname = f"{temp_dir.name}/tempsound.wav" if render else os.path.join(out, f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.wav')
|
| 378 |
+
sf.write(audioname, total_wav, sr)
|
| 379 |
+
outname = os.path.join(
|
| 380 |
+
out,
|
| 381 |
+
f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.mp4',
|
| 382 |
+
)
|
| 383 |
+
else:
|
| 384 |
+
assert type(name) == str
|
| 385 |
+
assert name != "", "Must provide an audio filename"
|
| 386 |
+
audioname = name
|
| 387 |
+
outname = os.path.join(
|
| 388 |
+
out, f"{epoch}_{os.path.splitext(os.path.basename(name))[0]}.mp4"
|
| 389 |
+
)
|
| 390 |
+
if render:
|
| 391 |
+
print(f"ffmpeg -loglevel error -stream_loop 0 -y -i {gifname} -i {audioname} -shortest -c:v libx264 -crf 26 -c:a aac -q:a 4 {outname}")
|
| 392 |
+
out = os.system(
|
| 393 |
+
f"/home/lrh/Documents/ffmpeg-6.0-amd64-static/ffmpeg -loglevel error -stream_loop 0 -y -i {gifname} -i {audioname} -shortest -c:v libx264 -crf 26 -c:a aac -q:a 4 {outname}"
|
| 394 |
+
)
|
| 395 |
+
else:
|
| 396 |
+
if render:
|
| 397 |
+
# actually save the gif
|
| 398 |
+
path = os.path.normpath(name)
|
| 399 |
+
pathparts = path.split(os.sep)
|
| 400 |
+
gifname = os.path.join(out, f"{pathparts[-1][:-4]}.gif")
|
| 401 |
+
anim.save(gifname, savefig_kwargs={"transparent": True, "facecolor": "none"},)
|
| 402 |
+
plt.close()
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class SMPLSkeleton:
|
| 406 |
+
def __init__(
|
| 407 |
+
self, device=None,
|
| 408 |
+
):
|
| 409 |
+
offsets = smpl_offsets
|
| 410 |
+
parents = smpl_parents
|
| 411 |
+
assert len(offsets) == len(parents)
|
| 412 |
+
|
| 413 |
+
self._offsets = torch.Tensor(offsets) #.to(device)
|
| 414 |
+
self._parents = np.array(parents)
|
| 415 |
+
self._compute_metadata()
|
| 416 |
+
|
| 417 |
+
def _compute_metadata(self):
|
| 418 |
+
self._has_children = np.zeros(len(self._parents)).astype(bool)
|
| 419 |
+
for i, parent in enumerate(self._parents):
|
| 420 |
+
if parent != -1:
|
| 421 |
+
self._has_children[parent] = True
|
| 422 |
+
|
| 423 |
+
self._children = []
|
| 424 |
+
for i, parent in enumerate(self._parents):
|
| 425 |
+
self._children.append([])
|
| 426 |
+
for i, parent in enumerate(self._parents):
|
| 427 |
+
if parent != -1:
|
| 428 |
+
self._children[parent].append(i)
|
| 429 |
+
|
| 430 |
+
def forward(self, rotations, root_positions):
|
| 431 |
+
"""
|
| 432 |
+
Perform forward kinematics using the given trajectory and local rotations.
|
| 433 |
+
Arguments (where N = batch size, L = sequence length, J = number of joints):
|
| 434 |
+
-- rotations: (N, L, J, 3) tensor of axis-angle rotations describing the local rotations of each joint.
|
| 435 |
+
-- root_positions: (N, L, 3) tensor describing the root joint positions.
|
| 436 |
+
"""
|
| 437 |
+
assert len(rotations.shape) == 4
|
| 438 |
+
assert len(root_positions.shape) == 3
|
| 439 |
+
# transform from axis angle to quaternion
|
| 440 |
+
fk_device = rotations.device
|
| 441 |
+
self._offsets.to(fk_device)
|
| 442 |
+
rotations = axis_angle_to_quaternion(rotations)
|
| 443 |
+
|
| 444 |
+
positions_world = []
|
| 445 |
+
rotations_world = []
|
| 446 |
+
|
| 447 |
+
expanded_offsets = self._offsets.expand(
|
| 448 |
+
rotations.shape[0],
|
| 449 |
+
rotations.shape[1],
|
| 450 |
+
self._offsets.shape[0],
|
| 451 |
+
self._offsets.shape[1],
|
| 452 |
+
).to(fk_device)
|
| 453 |
+
|
| 454 |
+
# Parallelize along the batch and time dimensions
|
| 455 |
+
for i in range(self._offsets.shape[0]):
|
| 456 |
+
if self._parents[i] == -1:
|
| 457 |
+
positions_world.append(root_positions)
|
| 458 |
+
rotations_world.append(rotations[:, :, 0])
|
| 459 |
+
else:
|
| 460 |
+
positions_world.append(
|
| 461 |
+
quaternion_apply(
|
| 462 |
+
rotations_world[self._parents[i]], expanded_offsets[:, :, i]
|
| 463 |
+
)
|
| 464 |
+
+ positions_world[self._parents[i]]
|
| 465 |
+
)
|
| 466 |
+
if self._has_children[i]:
|
| 467 |
+
rotations_world.append(
|
| 468 |
+
quaternion_multiply(
|
| 469 |
+
rotations_world[self._parents[i]], rotations[:, :, i]
|
| 470 |
+
)
|
| 471 |
+
)
|
| 472 |
+
else:
|
| 473 |
+
# This joint is a terminal node -> it would be useless to compute the transformation
|
| 474 |
+
rotations_world.append(None)
|
| 475 |
+
|
| 476 |
+
return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
@torch.no_grad()
|
| 480 |
+
class SMPLX_Skeleton:
|
| 481 |
+
def __init__(
|
| 482 |
+
self, device=None, batch=64,
|
| 483 |
+
):
|
| 484 |
+
# offsets = smpl_offsets
|
| 485 |
+
self.device = device
|
| 486 |
+
self.parents = smplx_parents
|
| 487 |
+
self.J = np.load(os.path.join(os.path.dirname(__file__), "smplx_neu_J_1.npy"))
|
| 488 |
+
self.J = torch.from_numpy(self.J).to(device).unsqueeze(dim=0).repeat(batch, 1, 1)
|
| 489 |
+
|
| 490 |
+
def batch_rodrigues(self, rot_vecs: Tensor, epsilon: float = 1e-8,) -> Tensor:
|
| 491 |
+
''' Calculates the rotation matrices for a batch of rotation vectors
|
| 492 |
+
Parameters
|
| 493 |
+
----------
|
| 494 |
+
rot_vecs: torch.tensor Nx3
|
| 495 |
+
array of N axis-angle vectors
|
| 496 |
+
Returns
|
| 497 |
+
-------
|
| 498 |
+
R: torch.tensor Nx3x3
|
| 499 |
+
The rotation matrices for the given axis-angle parameters
|
| 500 |
+
'''
|
| 501 |
+
batch_size = rot_vecs.shape[0]
|
| 502 |
+
device, dtype = rot_vecs.device, rot_vecs.dtype
|
| 503 |
+
|
| 504 |
+
angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
|
| 505 |
+
rot_dir = rot_vecs / angle
|
| 506 |
+
|
| 507 |
+
cos = torch.unsqueeze(torch.cos(angle), dim=1)
|
| 508 |
+
sin = torch.unsqueeze(torch.sin(angle), dim=1)
|
| 509 |
+
|
| 510 |
+
# Bx1 arrays
|
| 511 |
+
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
|
| 512 |
+
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
|
| 513 |
+
|
| 514 |
+
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
|
| 515 |
+
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
|
| 516 |
+
.view((batch_size, 3, 3))
|
| 517 |
+
|
| 518 |
+
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
|
| 519 |
+
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
|
| 520 |
+
return rot_mat
|
| 521 |
+
|
| 522 |
+
def batch_rigid_transform(self,
|
| 523 |
+
rot_mats: Tensor,
|
| 524 |
+
joints: Tensor,
|
| 525 |
+
parents: Tensor,
|
| 526 |
+
dtype=torch.float32
|
| 527 |
+
) -> Tensor:
|
| 528 |
+
"""
|
| 529 |
+
Applies a batch of rigid transformations to the joints
|
| 530 |
+
|
| 531 |
+
Parameters
|
| 532 |
+
----------
|
| 533 |
+
rot_mats : torch.tensor BxNx3x3
|
| 534 |
+
Tensor of rotation matrices
|
| 535 |
+
joints : torch.tensor BxNx3
|
| 536 |
+
Locations of joints
|
| 537 |
+
parents : torch.tensor BxN
|
| 538 |
+
The kinematic tree of each object
|
| 539 |
+
dtype : torch.dtype, optional:
|
| 540 |
+
The data type of the created tensors, the default is torch.float32
|
| 541 |
+
|
| 542 |
+
Returns
|
| 543 |
+
-------
|
| 544 |
+
posed_joints : torch.tensor BxNx3
|
| 545 |
+
The locations of the joints after applying the pose rotations
|
| 546 |
+
rel_transforms : torch.tensor BxNx4x4
|
| 547 |
+
The relative (with respect to the root joint) rigid transformations
|
| 548 |
+
for all the joints
|
| 549 |
+
"""
|
| 550 |
+
|
| 551 |
+
joints = torch.unsqueeze(joints, dim=-1)
|
| 552 |
+
# joints_check = joints.detach().cpu().numpy()
|
| 553 |
+
|
| 554 |
+
rel_joints = joints.clone()
|
| 555 |
+
rel_joints[:, 1:] -= joints[:, parents[1:]]
|
| 556 |
+
|
| 557 |
+
transforms_mat = self.transform_mat(
|
| 558 |
+
rot_mats.reshape(-1, 3, 3),
|
| 559 |
+
rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
|
| 560 |
+
|
| 561 |
+
transform_chain = [transforms_mat[:, 0]]
|
| 562 |
+
for i in range(1, parents.shape[0]):
|
| 563 |
+
# Subtract the joint location at the rest pose
|
| 564 |
+
# No need for rotation, since it's identity when at rest
|
| 565 |
+
curr_res = torch.matmul(transform_chain[parents[i]],
|
| 566 |
+
transforms_mat[:, i])
|
| 567 |
+
transform_chain.append(curr_res)
|
| 568 |
+
|
| 569 |
+
transforms = torch.stack(transform_chain, dim=1)
|
| 570 |
+
|
| 571 |
+
# The last column of the transformations contains the posed joints
|
| 572 |
+
posed_joints = transforms[:, :, :3, 3]
|
| 573 |
+
|
| 574 |
+
# joints_homogen = F.pad(joints, [0, 0, 0, 1])
|
| 575 |
+
|
| 576 |
+
# rel_transforms = transforms - F.pad(
|
| 577 |
+
# torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
|
| 578 |
+
|
| 579 |
+
return posed_joints #, rel_transforms
|
| 580 |
+
|
| 581 |
+
def transform_mat(self, R: Tensor, t: Tensor) -> Tensor:
|
| 582 |
+
''' Creates a batch of transformation matrices
|
| 583 |
+
Args:
|
| 584 |
+
- R: Bx3x3 array of a batch of rotation matrices
|
| 585 |
+
- t: Bx3x1 array of a batch of translation vectors
|
| 586 |
+
Returns:
|
| 587 |
+
- T: Bx4x4 Transformation matrix
|
| 588 |
+
'''
|
| 589 |
+
# No padding left or right, only add an extra row
|
| 590 |
+
return torch.cat([F.pad(R, [0, 0, 0, 1]),
|
| 591 |
+
F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
|
| 592 |
+
|
| 593 |
+
def motion_data_load_process(self, motionfile):
|
| 594 |
+
if motionfile.split(".")[-1] == "pkl":
|
| 595 |
+
pkl_data = pickle.load(open(motionfile, "rb"))
|
| 596 |
+
if "pos" in pkl_data.keys():
|
| 597 |
+
local_q_165 = torch.from_numpy(pkl_data["q"]).to(self.device).float()
|
| 598 |
+
root_pos = torch.from_numpy(pkl_data["pos"]).to(self.device).float()
|
| 599 |
+
root_pos = root_pos[:, :] - root_pos[0, :]
|
| 600 |
+
return local_q_165, root_pos
|
| 601 |
+
else:
|
| 602 |
+
smpl_poses = pkl_data["smpl_poses"]
|
| 603 |
+
if smpl_poses.shape[0] != 150 and smpl_poses.shape[0] != 300:
|
| 604 |
+
smpl_poses = smpl_poses.reshape(150, -1)
|
| 605 |
+
# modata = np.concatenate((pkl_data["smpl_trans"], smpl_poses), axis=1)
|
| 606 |
+
# assert modata.shape[1] == 159
|
| 607 |
+
# modata = torch.from_numpy(modata).to(f'cuda:{args.gpu}')
|
| 608 |
+
root_pos = pkl_data["smpl_trans"]
|
| 609 |
+
|
| 610 |
+
local_q = torch.from_numpy(smpl_poses).to(self.device).float()
|
| 611 |
+
root_pos = torch.from_numpy(root_pos).to(self.device).float()
|
| 612 |
+
local_q_165 = torch.cat([local_q[:, :66], torch.zeros([local_q.shape[0], 9], device=local_q.device, dtype=torch.float32), local_q[:, 66:]], dim=1).to(self.device).float()
|
| 613 |
+
root_pos = root_pos[:, :] - root_pos[0, :]
|
| 614 |
+
return local_q_165, root_pos
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def forward(self, rotations, root_positions):
|
| 618 |
+
"""
|
| 619 |
+
Perform forward kinematics using the given trajectory and local rotations.
|
| 620 |
+
Arguments (where N = batch size, L = sequence length, J = number of joints):
|
| 621 |
+
-- rotations: (N, 156) 或 (N, 165)
|
| 622 |
+
-- root_positions: (N, 3)
|
| 623 |
+
输出: N, 55, 3 关节点全局坐标
|
| 624 |
+
"""
|
| 625 |
+
# assert len(rotations.shape) == 4
|
| 626 |
+
# assert len(root_positions.shape) == 3
|
| 627 |
+
# print(fk_device)
|
| 628 |
+
fk_device = rotations.device
|
| 629 |
+
if rotations.shape[1] == 156:
|
| 630 |
+
local_q_165 = torch.cat([rotations[:, :66], torch.zeros([rotations.shape[0], 9], device=fk_device, dtype=torch.float32), rotations[:, 66:]], dim=1).to(fk_device).float()
|
| 631 |
+
elif rotations.shape[1] == 165:
|
| 632 |
+
local_q_165 = rotations.to(fk_device).float()
|
| 633 |
+
else:
|
| 634 |
+
print("rotations shape error", rotations.shape)
|
| 635 |
+
sys.exit(0)
|
| 636 |
+
|
| 637 |
+
root_pos = root_positions.to(fk_device).float()
|
| 638 |
+
assert local_q_165.shape[1] == 165
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
B, C = local_q_165.shape
|
| 642 |
+
# print("local_q shape is:", local_q_165.shape)
|
| 643 |
+
rot_mats = self.batch_rodrigues(local_q_165.view(-1, 3)).view(
|
| 644 |
+
[B, -1, 3, 3])
|
| 645 |
+
# J = np.load("/data/lrh/project/Dance/mdm_v2/model/smplx_neu_J_1.npy")
|
| 646 |
+
|
| 647 |
+
if self.J.shape[0] >= B:
|
| 648 |
+
J_temp = self.J[:B,:,:] #self.J = self.J[:B,:,:]
|
| 649 |
+
else:
|
| 650 |
+
J_temp = self.J[:1,:,:].repeat(B, 1, 1)
|
| 651 |
+
print("warning: self.J size 0 is lower than batchsize x seq_len")
|
| 652 |
+
|
| 653 |
+
parents = torch.Tensor(self.parents).long() # if self.parents is None else self.parents
|
| 654 |
+
J_transformed = self.batch_rigid_transform(rot_mats, J_temp, parents, dtype=torch.float32)
|
| 655 |
+
J_transformed += root_pos.unsqueeze(dim=1)
|
| 656 |
+
# J_transformed = J_transformed.detach().cpu().numpy()
|
| 657 |
+
|
| 658 |
+
return J_transformed
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
if __name__ == "__main__":
|
| 662 |
+
print("1")
|
| 663 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
smplx_fk = SMPLX_Skeleton(device = device, batch=150)
|
| 667 |
+
motion_file = "/home/data/lrh/datasets/fine_dance/magicsmpl/sliced/test/dances/012_slice0.pkl"
|
| 668 |
+
# music_file = "/home/data/lrh/datasets/fine_dance/magicsmpl/sliced/test/wavs/012_slice0.wav"
|
| 669 |
+
local_q_165, root_pos = smplx_fk.motion_data_load_process(motion_file)
|
| 670 |
+
print("local_q_165.shape", local_q_165.shape)
|
| 671 |
+
print("root_pos.shape", root_pos.shape)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
joints = smplx_fk.forward(local_q_165, root_pos).detach().cpu().numpy() # 150, 165 150, 3
|
| 675 |
+
|
| 676 |
+
print("joints.shape", joints.shape)
|
| 677 |
+
# skeleton_render(
|
| 678 |
+
# joints,
|
| 679 |
+
# epoch=f"e{1}_b{1}",
|
| 680 |
+
# out="./output/temp",
|
| 681 |
+
# name=music_file,
|
| 682 |
+
# render=True,
|
| 683 |
+
# stitch=False,
|
| 684 |
+
# sound=True,
|
| 685 |
+
# smpl_mode="smplx"
|
| 686 |
+
# )
|
| 687 |
+
|