thanks to shubham-goel ❤
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +7 -0
- slahmr/.gitignore +140 -0
- slahmr/.gitmodules +6 -0
- slahmr/LICENSE +21 -0
- slahmr/README.md +167 -0
- slahmr/download_models.sh +5 -0
- slahmr/env.yaml +45 -0
- slahmr/env_build.yaml +127 -0
- slahmr/install.sh +35 -0
- slahmr/requirements.txt +27 -0
- slahmr/setup.py +9 -0
- slahmr/slahmr.zip +3 -0
- slahmr/slahmr/__init__.py +0 -0
- slahmr/slahmr/body_model/__init__.py +3 -0
- slahmr/slahmr/body_model/body_model.py +142 -0
- slahmr/slahmr/body_model/specs.py +554 -0
- slahmr/slahmr/body_model/utils.py +56 -0
- slahmr/slahmr/confs/config.yaml +51 -0
- slahmr/slahmr/confs/data/3dpw.yaml +18 -0
- slahmr/slahmr/confs/data/3dpw_gt.yaml +18 -0
- slahmr/slahmr/confs/data/custom.yaml +17 -0
- slahmr/slahmr/confs/data/davis.yaml +16 -0
- slahmr/slahmr/confs/data/egobody.yaml +18 -0
- slahmr/slahmr/confs/data/posetrack.yaml +17 -0
- slahmr/slahmr/confs/data/video.yaml +24 -0
- slahmr/slahmr/confs/init.yaml +13 -0
- slahmr/slahmr/confs/optim.yaml +51 -0
- slahmr/slahmr/data/__init__.py +2 -0
- slahmr/slahmr/data/dataset.py +438 -0
- slahmr/slahmr/data/tools.py +108 -0
- slahmr/slahmr/data/vidproc.py +82 -0
- slahmr/slahmr/eval/__init__.py +0 -0
- slahmr/slahmr/eval/associate.py +161 -0
- slahmr/slahmr/eval/egobody_utils.py +171 -0
- slahmr/slahmr/eval/run_eval.py +289 -0
- slahmr/slahmr/eval/split_3dpw.py +99 -0
- slahmr/slahmr/eval/split_egobody.py +123 -0
- slahmr/slahmr/eval/tools.py +181 -0
- slahmr/slahmr/geometry/__init__.py +5 -0
- slahmr/slahmr/geometry/camera.py +348 -0
- slahmr/slahmr/geometry/mesh.py +110 -0
- slahmr/slahmr/geometry/pcl.py +60 -0
- slahmr/slahmr/geometry/plane.py +101 -0
- slahmr/slahmr/geometry/rotation.py +284 -0
- slahmr/slahmr/humor/__init__.py +0 -0
- slahmr/slahmr/humor/amass_utils.py +148 -0
- slahmr/slahmr/humor/humor_model.py +1655 -0
- slahmr/slahmr/humor/transforms.py +472 -0
- slahmr/slahmr/job_specs/3dpw_test_split.txt +248 -0
- slahmr/slahmr/job_specs/davis.txt +24 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
lietorch-0.2-py3.10-linux-x86_64.egg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
lietorch-0.2-py3.10-linux-x86_64.egg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
slahmr/teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
slahmr/third-party/DROID-SLAM/build/lib.linux-x86_64-3.10/droid_backends.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
slahmr/third-party/DROID-SLAM/build/lib.linux-x86_64-3.10/lietorch_backends.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
slahmr/third-party/DROID-SLAM/build/temp.linux-x86_64-3.10/src/droid_kernels.o filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
slahmr/third-party/DROID-SLAM/build/temp.linux-x86_64-3.10/thirdparty/lietorch/lietorch/src/lietorch_gpu.o filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
slahmr/third-party/DROID-SLAM/thirdparty/lietorch/examples/registration/assets/registration.gif filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
slahmr/third-party/ViTPose/demo/resources/demo_coco.gif filter=lfs diff=lfs merge=lfs -text
|
slahmr/.gitignore
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# data
|
| 7 |
+
*outputs*
|
| 8 |
+
*renders*
|
| 9 |
+
*cache*
|
| 10 |
+
*checkpoints*
|
| 11 |
+
*_DATA
|
| 12 |
+
|
| 13 |
+
*.swp
|
| 14 |
+
*.out
|
| 15 |
+
*.err
|
| 16 |
+
|
| 17 |
+
# C extensions
|
| 18 |
+
*.so
|
| 19 |
+
|
| 20 |
+
# Distribution / packaging
|
| 21 |
+
.Python
|
| 22 |
+
build/
|
| 23 |
+
develop-eggs/
|
| 24 |
+
dist/
|
| 25 |
+
downloads/
|
| 26 |
+
eggs/
|
| 27 |
+
.eggs/
|
| 28 |
+
lib/
|
| 29 |
+
lib64/
|
| 30 |
+
parts/
|
| 31 |
+
sdist/
|
| 32 |
+
var/
|
| 33 |
+
wheels/
|
| 34 |
+
pip-wheel-metadata/
|
| 35 |
+
share/python-wheels/
|
| 36 |
+
*.egg-info/
|
| 37 |
+
.installed.cfg
|
| 38 |
+
*.egg
|
| 39 |
+
MANIFEST
|
| 40 |
+
|
| 41 |
+
# PyInstaller
|
| 42 |
+
# Usually these files are written by a python script from a template
|
| 43 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 44 |
+
*.manifest
|
| 45 |
+
*.spec
|
| 46 |
+
|
| 47 |
+
# Installer logs
|
| 48 |
+
pip-log.txt
|
| 49 |
+
pip-delete-this-directory.txt
|
| 50 |
+
|
| 51 |
+
# Unit test / coverage reports
|
| 52 |
+
htmlcov/
|
| 53 |
+
.tox/
|
| 54 |
+
.nox/
|
| 55 |
+
.coverage
|
| 56 |
+
.coverage.*
|
| 57 |
+
.cache
|
| 58 |
+
nosetests.xml
|
| 59 |
+
coverage.xml
|
| 60 |
+
*.cover
|
| 61 |
+
*.py,cover
|
| 62 |
+
.hypothesis/
|
| 63 |
+
.pytest_cache/
|
| 64 |
+
|
| 65 |
+
# Translations
|
| 66 |
+
*.mo
|
| 67 |
+
*.pot
|
| 68 |
+
|
| 69 |
+
# Django stuff:
|
| 70 |
+
*.log
|
| 71 |
+
local_settings.py
|
| 72 |
+
db.sqlite3
|
| 73 |
+
db.sqlite3-journal
|
| 74 |
+
|
| 75 |
+
# Flask stuff:
|
| 76 |
+
instance/
|
| 77 |
+
.webassets-cache
|
| 78 |
+
|
| 79 |
+
# Scrapy stuff:
|
| 80 |
+
.scrapy
|
| 81 |
+
|
| 82 |
+
# Sphinx documentation
|
| 83 |
+
docs/_build/
|
| 84 |
+
|
| 85 |
+
# PyBuilder
|
| 86 |
+
target/
|
| 87 |
+
|
| 88 |
+
# Jupyter Notebook
|
| 89 |
+
.ipynb_checkpoints
|
| 90 |
+
|
| 91 |
+
# IPython
|
| 92 |
+
profile_default/
|
| 93 |
+
ipython_config.py
|
| 94 |
+
|
| 95 |
+
# pyenv
|
| 96 |
+
.python-version
|
| 97 |
+
|
| 98 |
+
# pipenv
|
| 99 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 100 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 101 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 102 |
+
# install all needed dependencies.
|
| 103 |
+
#Pipfile.lock
|
| 104 |
+
|
| 105 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 106 |
+
__pypackages__/
|
| 107 |
+
|
| 108 |
+
# Celery stuff
|
| 109 |
+
celerybeat-schedule
|
| 110 |
+
celerybeat.pid
|
| 111 |
+
|
| 112 |
+
# SageMath parsed files
|
| 113 |
+
*.sage.py
|
| 114 |
+
|
| 115 |
+
# Environments
|
| 116 |
+
.env
|
| 117 |
+
.venv
|
| 118 |
+
env/
|
| 119 |
+
venv/
|
| 120 |
+
ENV/
|
| 121 |
+
env.bak/
|
| 122 |
+
venv.bak/
|
| 123 |
+
|
| 124 |
+
# Spyder project settings
|
| 125 |
+
.spyderproject
|
| 126 |
+
.spyproject
|
| 127 |
+
|
| 128 |
+
# Rope project settings
|
| 129 |
+
.ropeproject
|
| 130 |
+
|
| 131 |
+
# mkdocs documentation
|
| 132 |
+
/site
|
| 133 |
+
|
| 134 |
+
# mypy
|
| 135 |
+
.mypy_cache/
|
| 136 |
+
.dmypy.json
|
| 137 |
+
dmypy.json
|
| 138 |
+
|
| 139 |
+
# Pyre type checker
|
| 140 |
+
.pyre/
|
slahmr/.gitmodules
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "third-party/DROID-SLAM"]
|
| 2 |
+
path = third-party/DROID-SLAM
|
| 3 |
+
url = https://github.com/princeton-vl/DROID-SLAM.git
|
| 4 |
+
[submodule "third-party/ViTPose"]
|
| 5 |
+
path = third-party/ViTPose
|
| 6 |
+
url = https://github.com/ViTAE-Transformer/ViTPose.git
|
slahmr/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 vye16
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
slahmr/README.md
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Decoupling Human and Camera Motion from Videos in the Wild
|
| 2 |
+
|
| 3 |
+
Official PyTorch implementation of the paper Decoupling Human and Camera Motion from Videos in the Wild
|
| 4 |
+
|
| 5 |
+
[Project page](https://vye16.github.io/slahmr/) | [ArXiv](https://arxiv.org/abs/2302.12827)
|
| 6 |
+
|
| 7 |
+
<img src="./teaser.png">
|
| 8 |
+
|
| 9 |
+
## [<img src="https://i.imgur.com/QCojoJk.png" width="40"> You can run SLAHMR in Google Colab](https://colab.research.google.com/drive/1knzxW3XuxiaBH6hcwx01cs6DfA4azv5E?usp=sharing)
|
| 10 |
+
|
| 11 |
+
## News
|
| 12 |
+
|
| 13 |
+
- [2023/07] We updated the code to support tracking from [4D Humans](https://shubham-goel.github.io/4dhumans/)! The original code remains in the `release` branch.
|
| 14 |
+
- [2023/02] Original release!
|
| 15 |
+
|
| 16 |
+
## Getting started
|
| 17 |
+
This code was tested on Ubuntu 22.04 LTS and requires a CUDA-capable GPU.
|
| 18 |
+
|
| 19 |
+
1. Clone repository and submodules
|
| 20 |
+
```
|
| 21 |
+
git clone --recursive https://github.com/vye16/slahmr.git
|
| 22 |
+
```
|
| 23 |
+
or initialize submodules if already cloned
|
| 24 |
+
```
|
| 25 |
+
git submodule update --init --recursive
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
2. Set up conda environment. Run
|
| 29 |
+
```
|
| 30 |
+
source install.sh
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
<details>
|
| 34 |
+
<summary>We also include the following steps for trouble-shooting.</summary>
|
| 35 |
+
|
| 36 |
+
* Create environment
|
| 37 |
+
```
|
| 38 |
+
conda env create -f env.yaml
|
| 39 |
+
conda activate slahmr
|
| 40 |
+
```
|
| 41 |
+
We use PyTorch 1.13.0 with CUDA 11.7. Please modify according to your setup; we've tested successfully for PyTorch 1.11 as well.
|
| 42 |
+
We've also included `env_build.yaml` to speed up installation using already-solved dependencies, though it might not be compatible with your CUDA driver.
|
| 43 |
+
|
| 44 |
+
* Install PHALP
|
| 45 |
+
```
|
| 46 |
+
pip install phalp[all]@git+https://github.com/brjathu/PHALP.git
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
* Install current source repo
|
| 50 |
+
```
|
| 51 |
+
pip install -e .
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
* Install ViTPose
|
| 55 |
+
```
|
| 56 |
+
pip install -v -e third-party/ViTPose
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
* Install DROID-SLAM (will take a while)
|
| 60 |
+
```
|
| 61 |
+
cd third-party/DROID-SLAM
|
| 62 |
+
python setup.py install
|
| 63 |
+
```
|
| 64 |
+
</details>
|
| 65 |
+
|
| 66 |
+
3. Download models from [here](https://drive.google.com/file/d/1GXAd-45GzGYNENKgQxFQ4PHrBp8wDRlW/view?usp=sharing). Run
|
| 67 |
+
```
|
| 68 |
+
./download_models.sh
|
| 69 |
+
```
|
| 70 |
+
or
|
| 71 |
+
```
|
| 72 |
+
gdown https://drive.google.com/uc?id=1GXAd-45GzGYNENKgQxFQ4PHrBp8wDRlW
|
| 73 |
+
unzip -q slahmr_dependencies.zip
|
| 74 |
+
rm slahmr_dependencies.zip
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
All models and checkpoints should have been unpacked in `_DATA`.
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
## Fitting to an RGB video:
|
| 81 |
+
For a custom video, you can edit the config file: `slahmr/confs/data/video.yaml`.
|
| 82 |
+
Then, from the `slahmr` directory, you can run:
|
| 83 |
+
```
|
| 84 |
+
python run_opt.py data=video run_opt=True run_vis=True
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
We use hydra to launch experiments, and all parameters can be found in `slahmr/confs/config.yaml`.
|
| 88 |
+
If you would like to update any aspect of logging or optimization tuning, update the relevant config files.
|
| 89 |
+
|
| 90 |
+
By default, we will log each run to `outputs/video-val/<DATE>/<VIDEO_NAME>`.
|
| 91 |
+
Each stage of optimization will produce a separate subdirectory, each of which will contain outputs saved throughout the optimization
|
| 92 |
+
and rendered videos of the final result for that stage of optimization.
|
| 93 |
+
The `motion_chunks` directory contains the outputs of the final stage of optimization,
|
| 94 |
+
`root_fit` and `smooth_fit` contain outputs of short, intermediate stages of optimization,
|
| 95 |
+
and `init` contains the initialized outputs before optimization.
|
| 96 |
+
|
| 97 |
+
We've provided a `run_vis.py` script for running visualization from logs after optimization.
|
| 98 |
+
From the `slahmr` directory, run
|
| 99 |
+
```
|
| 100 |
+
python run_vis.py --log_root <LOG_ROOT>
|
| 101 |
+
```
|
| 102 |
+
and it will visualize all log subdirectories in `<LOG_ROOT>`.
|
| 103 |
+
Each output npz file will contain the SMPL parameters for all optimized people, the camera intrinsics and extrinsics.
|
| 104 |
+
The `motion_chunks` output will contain additional predictions from the motion prior.
|
| 105 |
+
Please see `run_vis.py` for how to extract the people meshes from the output parameters.
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
## Fitting to specific datasets:
|
| 109 |
+
We provide configurations for dataset formats in `slahmr/confs/data`:
|
| 110 |
+
1. Posetrack in `slahmr/confs/data/posetrack.yaml`
|
| 111 |
+
2. Egobody in `slahmr/confs/data/egobody.yaml`
|
| 112 |
+
3. 3DPW in `slahmr/confs/data/3dpw.yaml`
|
| 113 |
+
4. Custom video in `slahmr/confs/data/video.yaml`
|
| 114 |
+
|
| 115 |
+
**Please make sure to update all paths to data in the config files.**
|
| 116 |
+
|
| 117 |
+
We include tools to both process existing datasets we evaluated on in the paper, and to process custom data and videos.
|
| 118 |
+
We include experiments from the paper on the Egobody, Posetrack, and 3DPW datasets.
|
| 119 |
+
|
| 120 |
+
If you want to run on a large number of videos, or if you want to select specific people tracks for optimization,
|
| 121 |
+
we recommend preprocesing in advance.
|
| 122 |
+
For a single downloaded video, there is no need to run preprocessing in advance.
|
| 123 |
+
|
| 124 |
+
From the `slahmr/preproc` directory, run PHALP on all your sequences
|
| 125 |
+
```
|
| 126 |
+
python launch_phalp.py --type <DATASET_TYPE> --root <DATASET_ROOT> --split <DATASET_SPLIT> --gpus <GPUS>
|
| 127 |
+
```
|
| 128 |
+
and run DROID-SLAM on all your sequences
|
| 129 |
+
```
|
| 130 |
+
python launch_slam.py --type <DATASET_TYPE> --root <DATASET_ROOT> --split <DATASET_SPLIT> --gpus <GPUS>
|
| 131 |
+
```
|
| 132 |
+
You can also update the paths to datasets in `slahmr/preproc/datasets.py` for repeated use.
|
| 133 |
+
|
| 134 |
+
Then, from the `slahmr` directory,
|
| 135 |
+
```
|
| 136 |
+
python run_opt.py data=<DATA_CFG> run_opt=True run_vis=True
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
We've provided a helper script `launch.py` for launching many optimization jobs in parallel.
|
| 140 |
+
You can specify job-specific arguments with a job spec file, such as the example files in `job_specs`,
|
| 141 |
+
and batch-specific arguments shared across all jobs as
|
| 142 |
+
```
|
| 143 |
+
python launch.py --gpus 1 2 -f job_specs/pt_val_shots.txt -s data=posetrack exp_name=posetrack_val
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
## Evaluation on 3D datasets
|
| 147 |
+
After launching and completing optimization on either the Egobody or 3DPW datasets,
|
| 148 |
+
you can evaluate the outputs with scripts in the `eval` directory.
|
| 149 |
+
Before running, please update `EGOBODY_ROOT` and `TDPW_ROOT` in `eval/tools.py`.
|
| 150 |
+
Then, run
|
| 151 |
+
```
|
| 152 |
+
python run_eval.py -d <DSET_TYPE> -i <RES_ROOT> -f <JOB_FILE>
|
| 153 |
+
```
|
| 154 |
+
where `<JOB_FILE>` is the same job file used to launch all optimization runs.
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
## BibTeX
|
| 158 |
+
|
| 159 |
+
If you use our code in your research, please cite the following paper:
|
| 160 |
+
```
|
| 161 |
+
@inproceedings{ye2023slahmr,
|
| 162 |
+
title={Decoupling Human and Camera Motion from Videos in the Wild},
|
| 163 |
+
author={Ye, Vickie and Pavlakos, Georgios and Malik, Jitendra and Kanazawa, Angjoo},
|
| 164 |
+
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 165 |
+
month={June},
|
| 166 |
+
year={2023}
|
| 167 |
+
}
|
slahmr/download_models.sh
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env bash
# Download and unpack the slahmr model dependencies archive (unpacks to _DATA).
#
# Fail fast (-e) so we never delete the zip after a failed download or
# extraction; -u and pipefail catch unset variables and mid-pipe failures.
# Uses `env bash` for portability, matching install.sh.
set -euo pipefail

# download models
gdown https://drive.google.com/uc?id=1GXAd-45GzGYNENKgQxFQ4PHrBp8wDRlW
unzip -q slahmr_dependencies.zip
rm slahmr_dependencies.zip
|
slahmr/env.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: slahmr
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- pytorch
|
| 5 |
+
- nvidia
|
| 6 |
+
- rusty1s
|
| 7 |
+
dependencies:
|
| 8 |
+
- python=3.9
|
| 9 |
+
- pytorch
|
| 10 |
+
- pytorch-cuda=11.7
|
| 11 |
+
- torchvision
|
| 12 |
+
- pytorch-scatter
|
| 13 |
+
- suitesparse
|
| 14 |
+
- pip
|
| 15 |
+
- pip:
|
| 16 |
+
- git+https://github.com/facebookresearch/detectron2.git
|
| 17 |
+
- git+https://github.com/brjathu/pytube.git
|
| 18 |
+
- git+https://github.com/nghorbani/configer
|
| 19 |
+
- setuptools==59.5.0
|
| 20 |
+
- torchgeometry==0.1.2
|
| 21 |
+
- tensorboard
|
| 22 |
+
- smplx
|
| 23 |
+
- pyrender
|
| 24 |
+
- open3d
|
| 25 |
+
- imageio-ffmpeg
|
| 26 |
+
- matplotlib
|
| 27 |
+
- opencv-python
|
| 28 |
+
- scipy
|
| 29 |
+
- scikit-image
|
| 30 |
+
- scikit-learn==0.22
|
| 31 |
+
- joblib
|
| 32 |
+
- cython
|
| 33 |
+
- tqdm
|
| 34 |
+
- hydra-core
|
| 35 |
+
- pyyaml
|
| 36 |
+
- chumpy
|
| 37 |
+
- gdown
|
| 38 |
+
- dill
|
| 39 |
+
- motmetrics
|
| 40 |
+
- scenedetect[opencv]
|
| 41 |
+
- einops
|
| 42 |
+
- mmcv==1.3.9
|
| 43 |
+
- timm==0.4.9
|
| 44 |
+
- xtcocotools==1.10
|
| 45 |
+
- pandas==1.4.0
|
slahmr/env_build.yaml
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: slahmr2
|
| 2 |
+
channels:
|
| 3 |
+
- defaults
|
| 4 |
+
dependencies:
|
| 5 |
+
- _libgcc_mutex=0.1=main
|
| 6 |
+
- _openmp_mutex=5.1=1_gnu
|
| 7 |
+
- bzip2=1.0.8=h7b6447c_0
|
| 8 |
+
- ca-certificates=2023.05.30=h06a4308_0
|
| 9 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
| 10 |
+
- libffi=3.4.4=h6a678d5_0
|
| 11 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 12 |
+
- libgomp=11.2.0=h1234567_1
|
| 13 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 14 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 15 |
+
- ncurses=6.4=h6a678d5_0
|
| 16 |
+
- openssl=3.0.9=h7f8727e_0
|
| 17 |
+
- python=3.10.12=h955ad1f_0
|
| 18 |
+
- readline=8.2=h5eee18b_0
|
| 19 |
+
- sqlite=3.41.2=h5eee18b_0
|
| 20 |
+
- tk=8.6.12=h1ccaba5_0
|
| 21 |
+
- tzdata=2023c=h04d1e81_0
|
| 22 |
+
- xz=5.4.2=h5eee18b_0
|
| 23 |
+
- zlib=1.2.13=h5eee18b_0
|
| 24 |
+
- pip:
|
| 25 |
+
- addict==2.4.0
|
| 26 |
+
- ansi2html==1.8.0
|
| 27 |
+
- appdirs==1.4.4
|
| 28 |
+
- asttokens==2.2.1
|
| 29 |
+
- attrs==23.1.0
|
| 30 |
+
- av==10.0.0
|
| 31 |
+
- backcall==0.2.0
|
| 32 |
+
- beautifulsoup4==4.12.2
|
| 33 |
+
- certifi==2022.12.7
|
| 34 |
+
- charset-normalizer==2.1.1
|
| 35 |
+
- click==8.1.4
|
| 36 |
+
- comm==0.1.3
|
| 37 |
+
- configargparse==1.5.5
|
| 38 |
+
- configer==1.4.1
|
| 39 |
+
- configparser==5.3.0
|
| 40 |
+
- contourpy==1.1.0
|
| 41 |
+
- cycler==0.11.0
|
| 42 |
+
- cython==0.29.36
|
| 43 |
+
- dash==2.11.1
|
| 44 |
+
- dash-core-components==2.0.0
|
| 45 |
+
- dash-html-components==2.0.0
|
| 46 |
+
- dash-table==5.0.0
|
| 47 |
+
- debugpy==1.6.7
|
| 48 |
+
- decorator==5.1.1
|
| 49 |
+
- droid-backends==0.0.0
|
| 50 |
+
- executing==1.2.0
|
| 51 |
+
- fastjsonschema==2.17.1
|
| 52 |
+
- flask==2.2.5
|
| 53 |
+
- fonttools==4.40.0
|
| 54 |
+
- gdown==4.7.1
|
| 55 |
+
- hmr2==0.0.0
|
| 56 |
+
- idna==3.4
|
| 57 |
+
- imageio-ffmpeg==0.4.8
|
| 58 |
+
- importlib-metadata==6.7.0
|
| 59 |
+
- ipdb==0.13.13
|
| 60 |
+
- ipykernel==6.24.0
|
| 61 |
+
- ipython==8.14.0
|
| 62 |
+
- ipywidgets==8.0.7
|
| 63 |
+
- itsdangerous==2.1.2
|
| 64 |
+
- jedi==0.18.2
|
| 65 |
+
- jinja2==3.1.2
|
| 66 |
+
- json-tricks==3.17.1
|
| 67 |
+
- jsonschema==4.18.0
|
| 68 |
+
- jsonschema-specifications==2023.6.1
|
| 69 |
+
- jupyter-client==8.3.0
|
| 70 |
+
- jupyter-core==5.3.1
|
| 71 |
+
- jupyterlab-widgets==3.0.8
|
| 72 |
+
- kiwisolver==1.4.4
|
| 73 |
+
- lietorch==0.2
|
| 74 |
+
- matplotlib==3.7.2
|
| 75 |
+
- matplotlib-inline==0.1.6
|
| 76 |
+
- mmcv==1.3.9
|
| 77 |
+
- munkres==1.1.4
|
| 78 |
+
- nbformat==5.7.0
|
| 79 |
+
- nest-asyncio==1.5.6
|
| 80 |
+
- oauthlib==3.2.2
|
| 81 |
+
- open3d==0.17.0
|
| 82 |
+
- pandas==1.4.0
|
| 83 |
+
- parso==0.8.3
|
| 84 |
+
- pexpect==4.8.0
|
| 85 |
+
- phalp==0.1.3
|
| 86 |
+
- pickleshare==0.7.5
|
| 87 |
+
- pillow==9.3.0
|
| 88 |
+
- pip==23.1.2
|
| 89 |
+
- plotly==5.15.0
|
| 90 |
+
- prompt-toolkit==3.0.39
|
| 91 |
+
- psutil==5.9.5
|
| 92 |
+
- ptyprocess==0.7.0
|
| 93 |
+
- pure-eval==0.2.2
|
| 94 |
+
- pyasn1==0.5.0
|
| 95 |
+
- pyasn1-modules==0.3.0
|
| 96 |
+
- pyparsing==3.0.9
|
| 97 |
+
- pyquaternion==0.9.9
|
| 98 |
+
- pysocks==1.7.1
|
| 99 |
+
- pytz==2023.3
|
| 100 |
+
- pyyaml==6.0
|
| 101 |
+
- pyzmq==25.1.0
|
| 102 |
+
- referencing==0.29.1
|
| 103 |
+
- requests==2.28.1
|
| 104 |
+
- retrying==1.3.4
|
| 105 |
+
- rpds-py==0.8.8
|
| 106 |
+
- scikit-learn==1.3.0
|
| 107 |
+
- scipy==1.11.1
|
| 108 |
+
- setuptools==59.5.0
|
| 109 |
+
- six==1.16.0
|
| 110 |
+
- soupsieve==2.4.1
|
| 111 |
+
- stack-data==0.6.2
|
| 112 |
+
- tenacity==8.2.2
|
| 113 |
+
- timm==0.4.9
|
| 114 |
+
- torch==1.13.0+cu117
|
| 115 |
+
- torch-scatter==2.1.1+pt113cu117
|
| 116 |
+
- torchgeometry==0.1.2
|
| 117 |
+
- torchvision==0.14.0+cu117
|
| 118 |
+
- tornado==6.3.2
|
| 119 |
+
- traitlets==5.9.0
|
| 120 |
+
- urllib3==1.26.13
|
| 121 |
+
- wcwidth==0.2.6
|
| 122 |
+
- werkzeug==2.2.3
|
| 123 |
+
- wheel==0.38.4
|
| 124 |
+
- widgetsnbextension==4.0.8
|
| 125 |
+
- xtcocotools==1.13
|
| 126 |
+
- yapf==0.40.1
|
| 127 |
+
- zipp==3.15.0
|
slahmr/install.sh
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env bash
# Environment setup for slahmr: creates a conda env, installs PyTorch and
# pinned dependencies, then installs PHALP, this repo, ViTPose, and
# DROID-SLAM from source.
#
# NOTE(review): `conda activate` only works when conda's shell hook is loaded
# in the current shell, which is why the README says to run `source install.sh`
# rather than executing it directly.
set -e

export CONDA_ENV_NAME=slahmr

conda create -n $CONDA_ENV_NAME python=3.10 -y

conda activate $CONDA_ENV_NAME

# install pytorch using pip, update with appropriate cuda drivers if necessary
pip install torch==1.13.0 torchvision==0.14.0 --index-url https://download.pytorch.org/whl/cu117
# uncomment if pip installation isn't working
# conda install pytorch=1.13.0 torchvision=0.14.0 pytorch-cuda=11.7 -c pytorch -c nvidia -y

# install pytorch scatter using pip, update with appropriate cuda drivers if necessary
# (wheel index must match the torch/cuda versions installed above)
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cu117.html
# uncomment if pip installation isn't working
# conda install pytorch-scatter -c pyg -y

# install PHALP
pip install phalp[all]@git+https://github.com/brjathu/PHALP.git

# install remaining requirements
pip install -r requirements.txt

# install source
pip install -e .

# install ViTPose
pip install -v -e third-party/ViTPose

# install DROID-SLAM
# (built in-place with setup.py; expects the submodule to be checked out)
cd third-party/DROID-SLAM
python setup.py install
cd ../..
|
slahmr/requirements.txt
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git+https://github.com/nghorbani/configer
|
| 2 |
+
setuptools==59.5.0
|
| 3 |
+
torchgeometry==0.1.2
|
| 4 |
+
tensorboard
|
| 5 |
+
numpy==1.23
|
| 6 |
+
smplx
|
| 7 |
+
pyrender
|
| 8 |
+
open3d
|
| 9 |
+
imageio-ffmpeg
|
| 10 |
+
matplotlib
|
| 11 |
+
opencv-python
|
| 12 |
+
scipy
|
| 13 |
+
scikit-image
|
| 14 |
+
joblib
|
| 15 |
+
cython
|
| 16 |
+
tqdm
|
| 17 |
+
hydra-core
|
| 18 |
+
pyyaml
|
| 19 |
+
chumpy
|
| 20 |
+
gdown
|
| 21 |
+
dill
|
| 22 |
+
motmetrics
|
| 23 |
+
einops
|
| 24 |
+
mmcv==1.3.9
|
| 25 |
+
timm==0.4.9
|
| 26 |
+
xtcocotools
|
| 27 |
+
pandas==1.4.0
|
slahmr/setup.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from setuptools import find_packages, setup

# Package sources live one level down, under the inner "slahmr" directory.
PKG_ROOT = "slahmr"

setup(
    name="slahmr",
    package_dir={"": PKG_ROOT},
    packages=find_packages(where=PKG_ROOT),
)
|
slahmr/slahmr.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e59655cc88dc71f3a9ac20425310bf996da3b6f15619497d17c67e7b60a671a
|
| 3 |
+
size 5662244525
|
slahmr/slahmr/__init__.py
ADDED
|
File without changes
|
slahmr/slahmr/body_model/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .body_model import *
|
| 2 |
+
from .specs import *
|
| 3 |
+
from .utils import *
|
slahmr/slahmr/body_model/body_model.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from smplx import SMPL, SMPLH, SMPLX
|
| 7 |
+
from smplx.vertex_ids import vertex_ids
|
| 8 |
+
from smplx.utils import Struct
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BodyModel(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Wrapper around SMPLX body model class.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
bm_path,
|
| 19 |
+
num_betas=10,
|
| 20 |
+
batch_size=1,
|
| 21 |
+
num_expressions=10,
|
| 22 |
+
use_vtx_selector=False,
|
| 23 |
+
model_type="smplh",
|
| 24 |
+
kid_template_path=None,
|
| 25 |
+
):
|
| 26 |
+
super(BodyModel, self).__init__()
|
| 27 |
+
"""
|
| 28 |
+
Creates the body model object at the given path.
|
| 29 |
+
|
| 30 |
+
:param bm_path: path to the body model pkl file
|
| 31 |
+
:param num_expressions: only for smplx
|
| 32 |
+
:param model_type: one of [smpl, smplh, smplx]
|
| 33 |
+
:param use_vtx_selector: if true, returns additional vertices as joints that correspond to OpenPose joints
|
| 34 |
+
"""
|
| 35 |
+
self.use_vtx_selector = use_vtx_selector
|
| 36 |
+
cur_vertex_ids = None
|
| 37 |
+
if self.use_vtx_selector:
|
| 38 |
+
cur_vertex_ids = vertex_ids[model_type]
|
| 39 |
+
data_struct = None
|
| 40 |
+
if ".npz" in bm_path:
|
| 41 |
+
# smplx does not support .npz by default, so have to load in manually
|
| 42 |
+
smpl_dict = np.load(bm_path, encoding="latin1")
|
| 43 |
+
data_struct = Struct(**smpl_dict)
|
| 44 |
+
# print(smpl_dict.files)
|
| 45 |
+
if model_type == "smplh":
|
| 46 |
+
data_struct.hands_componentsl = np.zeros((0))
|
| 47 |
+
data_struct.hands_componentsr = np.zeros((0))
|
| 48 |
+
data_struct.hands_meanl = np.zeros((15 * 3))
|
| 49 |
+
data_struct.hands_meanr = np.zeros((15 * 3))
|
| 50 |
+
V, D, B = data_struct.shapedirs.shape
|
| 51 |
+
data_struct.shapedirs = np.concatenate(
|
| 52 |
+
[data_struct.shapedirs, np.zeros((V, D, SMPL.SHAPE_SPACE_DIM - B))],
|
| 53 |
+
axis=-1,
|
| 54 |
+
) # super hacky way to let smplh use 16-size beta
|
| 55 |
+
kwargs = {
|
| 56 |
+
"model_type": model_type,
|
| 57 |
+
"data_struct": data_struct,
|
| 58 |
+
"num_betas": num_betas,
|
| 59 |
+
"batch_size": batch_size,
|
| 60 |
+
"num_expression_coeffs": num_expressions,
|
| 61 |
+
"vertex_ids": cur_vertex_ids,
|
| 62 |
+
"use_pca": False,
|
| 63 |
+
"flat_hand_mean": False,
|
| 64 |
+
}
|
| 65 |
+
if kid_template_path is not None:
|
| 66 |
+
kwargs["kid_template_path"] = kid_template_path
|
| 67 |
+
kwargs["age"] = "kid"
|
| 68 |
+
|
| 69 |
+
assert model_type in ["smpl", "smplh", "smplx"]
|
| 70 |
+
if model_type == "smpl":
|
| 71 |
+
self.bm = SMPL(bm_path, **kwargs)
|
| 72 |
+
self.num_joints = SMPL.NUM_JOINTS
|
| 73 |
+
elif model_type == "smplh":
|
| 74 |
+
self.bm = SMPLH(bm_path, **kwargs)
|
| 75 |
+
self.num_joints = SMPLH.NUM_JOINTS
|
| 76 |
+
elif model_type == "smplx":
|
| 77 |
+
self.bm = SMPLX(bm_path, **kwargs)
|
| 78 |
+
self.num_joints = SMPLX.NUM_JOINTS
|
| 79 |
+
|
| 80 |
+
self.model_type = model_type
|
| 81 |
+
|
| 82 |
+
def forward(
    self,
    root_orient=None,
    pose_body=None,
    pose_hand=None,
    pose_jaw=None,
    pose_eye=None,
    betas=None,
    trans=None,
    dmpls=None,
    expression=None,
    return_dict=False,
    **kwargs
):
    """Run the wrapped smplx body model and repackage its output.

    All pose inputs are batched along dim 0 and may be None, in which case
    the underlying model falls back to its own defaults. ``dmpls`` are not
    supported and must be None. Returns a Struct (or a dict when
    ``return_dict`` is True) with vertices, faces, betas, joints and poses.
    """
    assert dmpls is None

    # Split the concatenated hand pose into its left / right halves.
    lhand = rhand = None
    if pose_hand is not None:
        half = SMPLH.NUM_HAND_JOINTS * 3
        lhand, rhand = pose_hand[:, :half], pose_hand[:, half:]

    # Split the concatenated eye pose into left / right eye rotations.
    leye = reye = None
    if pose_eye is not None:
        leye, reye = pose_eye[:, :3], pose_eye[:, 3:]

    out_obj = self.bm(
        betas=betas,
        global_orient=root_orient,
        body_pose=pose_body,
        left_hand_pose=lhand,
        right_hand_pose=rhand,
        transl=trans,
        expression=expression,
        jaw_pose=pose_jaw,
        leye_pose=leye,
        reye_pose=reye,
        return_full_pose=True,
        **kwargs
    )

    out = {
        "v": out_obj.vertices,
        "f": self.bm.faces_tensor,
        "betas": out_obj.betas,
        "Jtr": out_obj.joints,
        "pose_body": out_obj.body_pose,
        "full_pose": out_obj.full_pose,
    }
    if self.model_type in ("smplh", "smplx"):
        out["pose_hand"] = torch.cat(
            [out_obj.left_hand_pose, out_obj.right_hand_pose], dim=-1
        )
    if self.model_type == "smplx":
        out["pose_jaw"] = out_obj.jaw_pose
        out["pose_eye"] = pose_eye

    if not self.use_vtx_selector:
        # don't need extra joints; +1 accounts for the root
        out["Jtr"] = out["Jtr"][:, : self.num_joints + 1]

    return out if return_dict else Struct(**out)
|
slahmr/slahmr/body_model/specs.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
# Name -> index for the 22 joints of the SMPL kinematic tree (root + 21 body).
SMPL_JOINTS = {
    "hips": 0,
    "leftUpLeg": 1,
    "rightUpLeg": 2,
    "spine": 3,
    "leftLeg": 4,
    "rightLeg": 5,
    "spine1": 6,
    "leftFoot": 7,
    "rightFoot": 8,
    "spine2": 9,
    "leftToeBase": 10,
    "rightToeBase": 11,
    "neck": 12,
    "leftShoulder": 13,
    "rightShoulder": 14,
    "head": 15,
    "leftArm": 16,
    "rightArm": 17,
    "leftForeArm": 18,
    "rightForeArm": 19,
    "leftHand": 20,
    "rightHand": 21,
}

# Parent joint index for each SMPL joint above (-1 marks the root).
SMPL_PARENTS = [
    -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7,
    8, 9, 12, 12, 12, 13, 14, 16, 17, 18, 19,
]

# Default on-disk locations of the body-model assets.
SMPLH_PATH = "./body_models/smplh"
SMPLX_PATH = "./body_models/smplx"
SMPL_PATH = "./body_models/smpl"
VPOSER_PATH = "./body_models/vposer_v1_0"

# chosen virtual mocap markers that are "keypoints" to work with
KEYPT_VERTS = [
    4404, 920, 3076, 3169, 823, 4310, 1010, 1085, 4495, 4569, 6615,
    3217, 3313, 6713, 6785, 3383, 6607, 3207, 1241, 1508, 4797, 4122,
    1618, 1569, 5135, 5040, 5691, 5636, 5404, 2230, 2173, 2108, 134,
    3645, 6543, 3123, 3024, 4194, 1306, 182, 3694, 4294, 744,
]


"""
Openpose
"""
OP_NUM_JOINTS = 25
# OP_IGNORE_JOINTS = [1, 9, 12]  # neck and left/right hip
OP_IGNORE_JOINTS = [1]  # neck
# Skeleton connectivity (bone list) of the OpenPose BODY_25 keypoints.
OP_EDGE_LIST = [
    [1, 8], [1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [8, 9],
    [9, 10], [10, 11], [8, 12], [12, 13], [13, 14], [1, 0], [0, 15],
    [15, 17], [0, 16], [16, 18], [14, 19], [19, 20], [14, 21], [11, 22],
    [22, 23], [11, 24],
]
# indices to map an openpose detection to its flipped version
OP_FLIP_MAP = [
    0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14,
    9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21,
]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
#
|
| 168 |
+
# From https://github.com/vchoutas/smplify-x/blob/master/smplifyx/utils.py
|
| 169 |
+
# Please see license for usage restrictions.
|
| 170 |
+
#
|
| 171 |
+
#
# From https://github.com/vchoutas/smplify-x/blob/master/smplifyx/utils.py
# Please see license for usage restrictions.
#
def smpl_to_openpose(
    model_type="smplx",
    use_hands=True,
    use_face=True,
    use_face_contour=False,
    openpose_format="coco25",
):
    """Returns the indices of the permutation that maps SMPL to OpenPose

    Parameters
    ----------
    model_type: str, optional
        The type of SMPL-like model that is used. The default mapping
        returned is for the SMPLX model
    use_hands: bool, optional
        Flag for adding to the returned permutation the mapping for the
        hand keypoints. Defaults to True
    use_face: bool, optional
        Flag for adding to the returned permutation the mapping for the
        face keypoints. Defaults to True
    use_face_contour: bool, optional
        Flag for appending the facial contour keypoints. Defaults to False
    openpose_format: str, optional
        The output format of OpenPose. For now only COCO-25 and COCO-19 is
        supported. Defaults to 'coco25'

    Raises
    ------
    ValueError
        If ``model_type`` or ``openpose_format`` is not recognized.
    """
    # FIX: the original folded case only for the "coco25" comparison, so
    # e.g. "COCO19" raised. Fold once so both formats are case-insensitive.
    fmt = openpose_format.lower()
    if fmt == "coco25":
        if model_type == "smpl":
            return np.array(
                [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7,
                 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
                dtype=np.int32,
            )
        elif model_type == "smplh":
            body_mapping = np.array(
                [52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7,
                 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
                dtype=np.int32,
            )
            mapping = [body_mapping]
            if use_hands:
                # 21 keypoints per hand: wrist + 4 per finger (tips live in
                # the extra-joint range of the model output).
                lhand_mapping = np.array(
                    [20, 34, 35, 36, 63, 22, 23, 24, 64, 25, 26, 27, 65,
                     31, 32, 33, 66, 28, 29, 30, 67],
                    dtype=np.int32,
                )
                rhand_mapping = np.array(
                    [21, 49, 50, 51, 68, 37, 38, 39, 69, 40, 41, 42, 70,
                     46, 47, 48, 71, 43, 44, 45, 72],
                    dtype=np.int32,
                )
                mapping += [lhand_mapping, rhand_mapping]
            return np.concatenate(mapping)
        # SMPLX
        elif model_type == "smplx":
            body_mapping = np.array(
                [55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7,
                 56, 57, 58, 59, 60, 61, 62, 63, 64, 65],
                dtype=np.int32,
            )
            mapping = [body_mapping]
            if use_hands:
                lhand_mapping = np.array(
                    [20, 37, 38, 39, 66, 25, 26, 27, 67, 28, 29, 30, 68,
                     34, 35, 36, 69, 31, 32, 33, 70],
                    dtype=np.int32,
                )
                rhand_mapping = np.array(
                    [21, 52, 53, 54, 71, 40, 41, 42, 72, 43, 44, 45, 73,
                     49, 50, 51, 74, 46, 47, 48, 75],
                    dtype=np.int32,
                )
                mapping += [lhand_mapping, rhand_mapping]
            if use_face:
                # 51 face landmarks, plus 17 jaw-contour points if requested
                face_mapping = np.arange(
                    76, 127 + 17 * use_face_contour, dtype=np.int32
                )
                mapping += [face_mapping]
            return np.concatenate(mapping)
        else:
            raise ValueError("Unknown model type: {}".format(model_type))
    elif fmt == "coco19":
        if model_type == "smpl":
            return np.array(
                [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28],
                dtype=np.int32,
            )
        elif model_type == "smplh":
            body_mapping = np.array(
                [52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 53, 54, 55, 56],
                dtype=np.int32,
            )
            mapping = [body_mapping]
            if use_hands:
                lhand_mapping = np.array(
                    [20, 34, 35, 36, 57, 22, 23, 24, 58, 25, 26, 27, 59,
                     31, 32, 33, 60, 28, 29, 30, 61],
                    dtype=np.int32,
                )
                rhand_mapping = np.array(
                    [21, 49, 50, 51, 62, 37, 38, 39, 63, 40, 41, 42, 64,
                     46, 47, 48, 65, 43, 44, 45, 66],
                    dtype=np.int32,
                )
                mapping += [lhand_mapping, rhand_mapping]
            return np.concatenate(mapping)
        # SMPLX
        elif model_type == "smplx":
            body_mapping = np.array(
                [55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 56, 57, 58, 59],
                dtype=np.int32,
            )
            mapping = [body_mapping]
            if use_hands:
                lhand_mapping = np.array(
                    [20, 37, 38, 39, 60, 25, 26, 27, 61, 28, 29, 30, 62,
                     34, 35, 36, 63, 31, 32, 33, 64],
                    dtype=np.int32,
                )
                rhand_mapping = np.array(
                    [21, 52, 53, 54, 65, 40, 41, 42, 66, 43, 44, 45, 67,
                     49, 50, 51, 68, 46, 47, 48, 69],
                    dtype=np.int32,
                )
                mapping += [lhand_mapping, rhand_mapping]
            if use_face:
                face_mapping = np.arange(
                    70, 70 + 51 + 17 * use_face_contour, dtype=np.int32
                )
                mapping += [face_mapping]
            return np.concatenate(mapping)
        else:
            raise ValueError("Unknown model type: {}".format(model_type))
    else:
        raise ValueError("Unknown joint format: {}".format(openpose_format))
|
slahmr/slahmr/body_model/utils.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .specs import SMPL_JOINTS
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def run_smpl(body_model, trans, root_orient, body_pose, betas=None):
    """
    Run a forward pass of the (wrapped) SMPL body model over batched pose
    sequences and return the posed joints, vertices, and faces.

    body_model : callable wrapper around a smplx model; assumed to expose
        ``bm.batch_size`` and ``bm.num_betas`` and to accept flattened
        (B*T, -1) inputs — TODO confirm against BodyModel in this package.
    trans : B x T x 3 root translations
    root_orient : B x T x 3 axis-angle root orientations
    body_pose : B x T x J*3 axis-angle body poses
    betas : (optional) B x D shape coefficients; zeros are used if omitted

    Returns a dict with "joints" (B x T x -1 x 3), "vertices"
    (B x T x -1 x 3) and "faces"; any zero-padded frames added below are
    stripped back to length T on the way out.
    """
    B, T, _ = trans.shape
    # The underlying smplx model has a fixed batch size; it must be a
    # multiple of B so each of the B sequences gets seq_len frames.
    bm_batch_size = body_model.bm.batch_size
    assert bm_batch_size % B == 0
    seq_len = bm_batch_size // B
    bm_num_betas = body_model.bm.num_betas
    J_BODY = len(SMPL_JOINTS) - 1  # all joints except root
    if T == 1:
        # must expand to use with body model
        trans = trans.expand(B, seq_len, 3)
        root_orient = root_orient.expand(B, seq_len, 3)
        body_pose = body_pose.expand(B, seq_len, J_BODY * 3)
    elif T != seq_len:
        # zero-pad the time dimension up to the model's fixed seq_len
        trans, root_orient, body_pose = zero_pad_tensors(
            [trans, root_orient, body_pose], seq_len - T
        )
    if betas is None:
        betas = torch.zeros(B, bm_num_betas, device=trans.device)
    # share one shape vector across every frame of each sequence
    betas = betas.reshape((B, 1, bm_num_betas)).expand((B, seq_len, bm_num_betas))
    smpl_body = body_model(
        pose_body=body_pose.reshape((B * seq_len, -1)),
        pose_hand=None,
        betas=betas.reshape((B * seq_len, -1)),
        root_orient=root_orient.reshape((B * seq_len, -1)),
        trans=trans.reshape((B * seq_len, -1)),
    )
    # un-flatten the batch and drop padded frames with [:, :T]
    return {
        "joints": smpl_body.Jtr.reshape(B, seq_len, -1, 3)[:, :T],
        "vertices": smpl_body.v.reshape(B, seq_len, -1, 3)[:, :T],
        "faces": smpl_body.f,
    }
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def zero_pad_tensors(pad_list, pad_size):
    """
    Zero-pad each tensor in ``pad_list`` along its temporal dimension.

    Assumes tensors in pad_list are B x T x D; appends ``pad_size`` zero
    frames so each result is B x (T + pad_size) x D.

    :param pad_list: list of B x T x D tensors
    :param pad_size: number of zero frames to append (int >= 0)
    :returns: a new list of padded tensors; inputs are not modified
    """
    new_pad_list = []
    for pad_tensor in pad_list:
        # new_zeros allocates directly on the source tensor's device with
        # its dtype (the original torch.zeros(...).to(...) built the buffer
        # on CPU first and then copied it over).
        padding = pad_tensor.new_zeros(
            (pad_tensor.size(0), pad_size, pad_tensor.size(2))
        )
        new_pad_list.append(torch.cat([pad_tensor, padding], dim=1))
    return new_pad_list
|
slahmr/slahmr/confs/config.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Top-level Hydra config for slahmr optimization runs.
defaults:
  - data: posetrack
  - optim
  - _self_

model:
  floor_type: "shared"
  est_floor: False
  use_init: True
  opt_cams: False
  opt_scale: True
  async_tracks: True

overwrite: False
run_opt: False
run_vis: False
vis:
  phases:
    - motion_chunks
    - input
  render_views:
    - src_cam
    - above
    - side
  make_grid: True
  overwrite: False

# Locations of pretrained body models and prior checkpoints.
paths:
  smpl: _DATA/body_models/smplh/neutral/model.npz
  smpl_kid: _DATA/body_models/smpl_kid_template.npy
  vposer: _DATA/body_models/vposer_v1_0
  init_motion_prior: _DATA/humor_ckpts/init_state_prior_gmm
  humor: _DATA/humor_ckpts/humor/best_model.pth

# HuMoR motion-prior architecture settings (must match the checkpoint above).
humor:
  in_rot_rep: "mat"
  out_rot_rep: "aa"
  latent_size: 48
  model_data_config: "smpl+joints+contacts"
  steps_in: 1

fps: 30
log_root: ../outputs/logs
log_dir: ${log_root}/${data.type}-${data.split}
exp_name: ${now:%Y-%m-%d}

hydra:
  job:
    chdir: True
  run:
    dir: ${log_dir}/${exp_name}/${data.name}
|
slahmr/slahmr/confs/data/3dpw.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data config for 3DPW sequences using predicted (tracked) detections.
type: 3dpw
split: test
seq: downtown_arguing_00
root: /path/to/3DPW
use_cams: True
split_cameras: False
camera_name: cameras_intrins_split
shot_idx: 0
start_idx: 0
end_idx: 100
track_ids: "longest-2"
sources:
  images: ${data.root}/imageFiles/${data.seq}
  cameras: ${data.root}/slahmr/${data.camera_name}/${data.seq}/${data.start_idx}-${data.end_idx}
  intrins: ${data.root}/slahmr/cameras_gt/${data.seq}/intrinsics.txt
  tracks: ${data.root}/slahmr/track_preds/${data.seq}
  shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
name: ${data.seq}-${data.track_ids}-${data.start_idx}-${data.end_idx}
|
slahmr/slahmr/confs/data/3dpw_gt.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data config for 3DPW sequences using ground-truth tracks (track_gt).
type: 3dpw_gt
split: test
seq: downtown_runForBus_00
root: /path/to/3DPW
use_cams: True
split_cameras: False
camera_name: cameras_intrins_split
shot_idx: 0
start_idx: 0
end_idx: 100
track_ids: "longest-2"
sources:
  images: ${data.root}/imageFiles/${data.seq}
  cameras: ${data.root}/slahmr/${data.camera_name}/${data.seq}/${data.start_idx}-${data.end_idx}
  intrins: ${data.root}/slahmr/cameras_gt/${data.seq}/intrinsics.txt
  tracks: ${data.root}/slahmr/track_gt/${data.seq}
  shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
name: ${data.seq}-${data.track_ids}-${data.start_idx}-${data.end_idx}
|
slahmr/slahmr/confs/data/custom.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data config for user-provided custom image sequences.
type: custom
split: val
video: ""
seq: ""
root: /path/to/custom
use_cams: True
track_ids: "all"
shot_idx: 0
start_idx: 0
end_idx: 200
split_cameras: True
name: ${data.seq}-${data.track_ids}-shot-${data.shot_idx}
sources:
  images: ${data.root}/images/${data.seq}
  cameras: ${data.root}/slahmr/cameras/${data.seq}/shot-${data.shot_idx}
  # NOTE(review): every other data config names this key "tracks" — confirm
  # which spelling the dataset loader expects.
  track: ${data.root}/slahmr/track_preds/${data.seq}
  shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
|
slahmr/slahmr/confs/data/davis.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data config for DAVIS video sequences.
type: davis
split: all
seq: parkour
root: /path/to/DAVIS
use_cams: True
track_ids: "all"
shot_idx: 0
start_idx: 0
end_idx: -1  # -1 = use all frames
split_cameras: True
name: ${data.seq}-${data.track_ids}-shot-${data.shot_idx}
sources:
  images: ${data.root}/JPEGImages/Full-Resolution/${data.seq}
  cameras: ${data.root}/slahmr/cameras/${data.seq}/shot-${data.shot_idx}
  tracks: ${data.root}/slahmr/track_preds/${data.seq}
  shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
|
slahmr/slahmr/confs/data/egobody.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data config for EgoBody egocentric recordings.
type: egobody
split: val
seq: recording_20210921_S11_S10_01
root: /path/to/egobody
use_cams: True
camera_name: cameras_intrins_split
shot_idx: 0
start_idx: 0
end_idx: 100
split_cameras: False
track_ids: "all"
sources:
  # ** glob: the PV image folder sits under a device-specific subdirectory
  images: ${data.root}/egocentric_color/${data.seq}/**/PV
  cameras: ${data.root}/slahmr/${data.camera_name}/${data.seq}/${data.start_idx}-${data.end_idx}
  intrins: ${data.root}/slahmr/cameras_gt/${data.seq}/intrinsics.txt
  tracks: ${data.root}/slahmr/track_preds/${data.seq}
  shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
name: ${data.seq}-${data.track_ids}-${data.start_idx}-${data.end_idx}
|
slahmr/slahmr/confs/data/posetrack.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data config for PoseTrack sequences.
type: posetrack
split: val
seq: 014286_mpii_train
root: /path/to/posetrack
use_cams: True
track_ids: "all"
shot_idx: 0
start_idx: 0
end_idx: -1  # -1 = use all frames
split_cameras: True
name: ${data.seq}-${data.track_ids}-shot-${data.shot_idx}
track_name: track_preds
sources:
  images: ${data.root}/images/${data.split}/${data.seq}
  cameras: ${data.root}/slahmr/${data.split}/cameras/${data.seq}/shot-${data.shot_idx}
  tracks: ${data.root}/slahmr/${data.split}/${data.track_name}/${data.seq}
  shots: ${data.root}/slahmr/${data.split}/shot_idcs/${data.seq}.json
|
slahmr/slahmr/confs/data/video.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data config for arbitrary input videos (frames are extracted first).
type: video
split: val
root: /path/to/data # put your videos in root/videos/vid.mp4
video_dir: videos
seq: basketball
ext: mp4
src_path: ${data.root}/${data.video_dir}/${data.seq}.${data.ext}
# Options used when extracting frames from the source video.
frame_opts:
  ext: jpg
  fps: 25
  start_sec: 0
  end_sec: -1  # -1 = until the end of the video
use_cams: True
track_ids: "all"
shot_idx: 0
start_idx: 0
end_idx: 180
split_cameras: True
name: ${data.seq}-${data.track_ids}-shot-${data.shot_idx}-${data.start_idx}-${data.end_idx}
sources:
  images: ${data.root}/images/${data.seq}
  cameras: ${data.root}/slahmr/cameras/${data.seq}/shot-${data.shot_idx}
  tracks: ${data.root}/slahmr/track_preds/${data.seq}
  shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
|
slahmr/slahmr/confs/init.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hydra config for the initialization stage.
defaults:
  - data: posetrack
  - _self_

gap: 1
log_root: outputs
save_per_frame: False
print_err: False
stride: 48

hydra:
  run:
    # NOTE(review): interpolates ${data.depth_dir} and ${data.fov}, which the
    # visible data configs do not define — confirm they are set elsewhere.
    dir: ${log_root}/init/${data.type}-${data.seq}-${data.depth_dir}-${data.fov}fov-gap${gap}
|
slahmr/slahmr/confs/optim.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Optimization-stage settings. Each loss weight below is a triple — one value
# per optimization phase (a weight of 0.0 disables the term in that phase).
optim:
  options:
    robust_loss_type: "bisquare"
    robust_tuning_const: 4.6851
    joints2d_sigma: 100.0
    lr: 1.0
    lbfgs_max_iter: 20
    save_every: 20
    vis_every: -1  # -1 disables intermediate visualization
    max_chunk_steps: 20
    save_meshes: False

  root:
    num_iters: 30

  smpl:
    num_iters: 0

  smooth:
    opt_scale: False
    num_iters: 60

  motion_chunks:
    chunk_size: 10
    init_steps: 20
    chunk_steps: 20
    opt_cams: True

  # NOTE(review): nesting level reconstructed from a mangled diff — confirm
  # whether loss_weights belongs under optim or at the top level.
  loss_weights:
    joints2d: [0.001, 0.001, 0.001]
    bg2d: [0.0, 0.000, 0.000]
    cam_R_smooth : [0.0, 0.0, 0.0]
    cam_t_smooth : [0.0, 0.0, 0.0]
    # bg2d: [0.0, 0.0001, 0.0001]
    # cam_R_smooth : [0.0, 1000.0, 1000.0]
    # cam_t_smooth : [0.0, 1000.0, 1000.0]
    joints3d: [0.0, 0.0, 0.0]
    joints3d_smooth: [1.0, 10.0, 0.0]
    joints3d_rollout: [0.0, 0.0, 0.0]
    verts3d: [0.0, 0.0, 0.0]
    points3d: [0.0, 0.0, 0.0]
    pose_prior: [0.04, 0.04, 0.04]
    shape_prior: [0.05, 0.05, 0.05]
    motion_prior: [0.0, 0.0, 0.075]
    init_motion_prior: [0.0, 0.0, 0.075]
    joint_consistency: [0.0, 0.0, 100.0]
    bone_length: [0.0, 0.0, 2000.0]
    contact_vel: [0.0, 0.0, 100.0]
    contact_height: [0.0, 0.0, 10.0]
    floor_reg: [0.0, 0.0, 0.0]
    # floor_reg: [0.0, 0.0, 0.167]
|
slahmr/slahmr/data/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dataset import *
|
| 2 |
+
from . import tools
|
slahmr/slahmr/data/dataset.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import typing
|
| 4 |
+
|
| 5 |
+
import imageio
|
| 6 |
+
import numpy as np
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
|
| 13 |
+
from body_model import OP_NUM_JOINTS, SMPL_JOINTS
|
| 14 |
+
from util.logger import Logger
|
| 15 |
+
from geometry.camera import invert_camera
|
| 16 |
+
|
| 17 |
+
from .tools import read_keypoints, read_mask_path, load_smpl_preds
|
| 18 |
+
from .vidproc import preprocess_cameras, preprocess_frames, preprocess_tracks
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
"""
|
| 22 |
+
Define data-related constants
|
| 23 |
+
"""
|
| 24 |
+
DEFAULT_GROUND = np.array([0.0, -1.0, 0.0, -0.5])
|
| 25 |
+
|
| 26 |
+
# XXX: TEMPORARY CONSTANTS
|
| 27 |
+
SHOT_PAD = 0
|
| 28 |
+
MIN_SEQ_LEN = 20
|
| 29 |
+
MAX_NUM_TRACKS = 12
|
| 30 |
+
MIN_TRACK_LEN = 20
|
| 31 |
+
MIN_KEYP_CONF = 0.4
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_dataset_from_cfg(cfg):
    """Build a MultiPeopleDataset from the `data` section of the run config.

    Resolves glob-style source paths, triggers any missing preprocessing
    (frames / tracks / cameras), and returns the dataset object.
    """
    data_cfg = cfg.data
    # camera sources are ignored entirely when cameras are disabled
    if not data_cfg.use_cams:
        data_cfg.sources.cameras = ""

    data_cfg.sources = expand_source_paths(data_cfg.sources)
    print("DATA SOURCES", data_cfg.sources)
    check_data_sources(data_cfg)
    return MultiPeopleDataset(
        data_cfg.sources,
        data_cfg.seq,
        tid_spec=data_cfg.track_ids,
        shot_idx=data_cfg.shot_idx,
        start_idx=int(data_cfg.start_idx),
        end_idx=int(data_cfg.end_idx),
        split_cameras=data_cfg.get("split_cameras", True),
    )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def expand_source_paths(data_sources):
    """Resolve every glob-style source path in the mapping via get_data_source."""
    resolved = {}
    for key, pattern in data_sources.items():
        resolved[key] = get_data_source(pattern)
    return resolved
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_data_source(source):
    """Resolve a glob pattern to a unique existing path.

    Returns the pattern unchanged when nothing matches (callers rely on it
    as a default value), and raises ValueError if the match is ambiguous.
    """
    candidates = glob.glob(source)
    if not candidates:
        print(f"{source} does not exist")
        return source  # return anyway for default values
    if len(candidates) > 1:
        raise ValueError(f"{source} is not unique")
    return candidates[0]
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def check_data_sources(args):
    """For video-type data, run any preprocessing that has not been done yet."""
    if args.type != "video":
        return
    # each step is a no-op if its outputs already exist
    preprocess_frames(args.sources.images, args.src_path, **args.frame_opts)
    preprocess_tracks(args.sources.images, args.sources.tracks, args.sources.shots)
    preprocess_cameras(args, overwrite=args.get("overwrite_cams", False))
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class MultiPeopleDataset(Dataset):
    """Dataset over the person tracks detected in one video sequence.

    Each dataset element (__getitem__) is one person track: its 2D keypoints,
    per-frame visibility, initial SMPL estimates, and floor plane. Camera data
    for the whole sequence is held separately (see get_camera_data).
    """

    def __init__(
        self,
        data_sources: typing.Dict,
        seq_name,
        tid_spec="all",
        shot_idx=0,
        start_idx=0,
        end_idx=-1,
        pad_shot=False,
        split_cameras=True,
    ):
        # :param data_sources: dict with keys "shots", "images", "tracks",
        #     "cameras" mapping to paths (see get_dataset_from_cfg)
        # :param seq_name: name of the sequence (carried through to outputs)
        # :param tid_spec: "all", "longest-<n>", or "-"-separated track ids
        # :param shot_idx: which shot of the video to use
        # :param start_idx, end_idx: frame range within the shot (end < 0 = all)
        # :param pad_shot: drop frames around shot boundaries (see get_shot_img_files)
        # :param split_cameras: whether camera data is restricted to this range
        self.seq_name = seq_name
        self.data_sources = data_sources
        self.split_cameras = split_cameras

        # select only images in the desired shot
        img_files, _ = get_shot_img_files(
            self.data_sources["shots"], shot_idx, pad_shot
        )
        end_idx = end_idx if end_idx > 0 else len(img_files)
        self.data_start, self.data_end = start_idx, end_idx
        img_files = img_files[start_idx:end_idx]
        self.img_names = [get_name(f) for f in img_files]
        self.num_imgs = len(self.img_names)

        img_dir = self.data_sources["images"]
        assert os.path.isdir(img_dir)
        self.img_paths = [os.path.join(img_dir, f) for f in img_files]
        # read one image only to discover the frame resolution
        img_h, img_w = imageio.imread(self.img_paths[0]).shape[:2]
        self.img_size = img_w, img_h
        print(f"USING TOTAL {self.num_imgs} {img_w}x{img_h} IMGS")

        # find the tracks in the video
        track_root = self.data_sources["tracks"]
        if tid_spec == "all" or tid_spec.startswith("longest"):
            n_tracks = MAX_NUM_TRACKS
            if tid_spec.startswith("longest"):
                n_tracks = int(tid_spec.split("-")[1])
            # get the longest tracks in the selected shot
            track_ids = sorted(os.listdir(track_root))
            track_paths = [
                [f"{track_root}/{tid}/{name}_keypoints.json" for name in self.img_names]
                for tid in track_ids
            ]
            # track length = number of frames with a keypoint file on disk
            track_lens = [
                len(list(filter(os.path.isfile, paths))) for paths in track_paths
            ]
            # keep tracks longer than MIN_TRACK_LEN, longest first
            track_ids = [
                track_ids[i]
                for i in np.argsort(track_lens)[::-1]
                if track_lens[i] > MIN_TRACK_LEN
            ]
            print("TRACK LENGTHS", track_ids, track_lens)
            track_ids = track_ids[:n_tracks]
        else:
            # explicit "-"-separated list of integer track ids
            track_ids = [f"{int(tid):03d}" for tid in tid_spec.split("-")]

        print("TRACK IDS", track_ids)

        self.track_ids = track_ids
        self.n_tracks = len(track_ids)
        self.track_dirs = [os.path.join(track_root, tid) for tid in track_ids]

        # keep a list of frame index masks of whether a track is available in a frame
        sidx = np.inf
        eidx = -1
        self.track_vis_masks = []
        for pred_dir in self.track_dirs:
            kp_paths = [f"{pred_dir}/{x}_keypoints.json" for x in self.img_names]
            has_kp = [os.path.isfile(x) for x in kp_paths]

            # keep track of which frames this track is visible in
            vis_mask = np.array(has_kp)
            idcs = np.where(vis_mask)[0]
            if len(idcs) > 0:
                si, ei = min(idcs), max(idcs)
                sidx = min(sidx, si)
                eidx = max(eidx, ei)
            self.track_vis_masks.append(vis_mask)

        # [sidx, eidx) is the union span over all tracks
        eidx = max(eidx + 1, 0)
        sidx = min(sidx, eidx)
        print("START", sidx, "END", eidx)
        self.start_idx = sidx
        self.end_idx = eidx
        self.seq_len = eidx - sidx
        self.seq_intervals = [(sidx, eidx) for _ in track_ids]

        self.sel_img_paths = self.img_paths[sidx:eidx]
        self.sel_img_names = self.img_names[sidx:eidx]

        # used to cache data
        self.data_dict = {}
        self.cam_data = None

    def __len__(self):
        # one element per person track
        return self.n_tracks

    def load_data(self, interp_input=True):
        """Load keypoints, SMPL initializations and cameras for all tracks.

        Results are cached in self.data_dict; subsequent calls are no-ops.
        :param interp_input: interpolate SMPL predictions over occluded frames
        """
        if len(self.data_dict) > 0:
            return

        # load camera data
        self.load_camera_data()
        # get data for each track
        data_out = {
            "mask_paths": [],
            "floor_plane": [],
            "joints2d": [],
            "vis_mask": [],
            "track_interval": [],
            "init_body_pose": [],
            "init_root_orient": [],
            "init_trans": [],
        }

        # create batches of sequences
        # each batch is a track for a person
        T = self.seq_len
        sidx, eidx = self.start_idx, self.end_idx
        for i, tid in enumerate(self.track_ids):
            # load mask of visible frames for this track
            vis_mask = self.track_vis_masks[i][sidx:eidx]  # (T)
            vis_idcs = np.where(vis_mask)[0]
            track_s, track_e = min(vis_idcs), max(vis_idcs) + 1
            data_out["track_interval"].append([track_s, track_e])

            # -1 out of scene / 0 occluded / 1 visible
            vis_mask = get_ternary_mask(vis_mask)
            data_out["vis_mask"].append(vis_mask)

            # load 2d keypoints for visible frames
            kp_paths = [
                f"{self.track_dirs[i]}/{x}_keypoints.json" for x in self.sel_img_names
            ]
            # (T, J, 3) (x, y, conf)
            joints2d_data = np.stack(
                [read_keypoints(p) for p in kp_paths], axis=0
            ).astype(np.float32)
            # Discard bad ViTPose detections
            joints2d_data[
                np.repeat(joints2d_data[:, :, [2]] < MIN_KEYP_CONF, 3, axis=2)
            ] = 0
            data_out["joints2d"].append(joints2d_data)

            # load single image smpl predictions
            pred_paths = [
                f"{self.track_dirs[i]}/{x}_smpl.json" for x in self.sel_img_names
            ]
            pose_init, orient_init, trans_init, _ = load_smpl_preds(
                pred_paths, interp=interp_input
            )

            # drop the root joint; only body joints are kept
            n_joints = len(SMPL_JOINTS) - 1
            data_out["init_body_pose"].append(pose_init[:, :n_joints, :])
            data_out["init_root_orient"].append(orient_init)
            data_out["init_trans"].append(trans_init)

            # every track gets the same default ground plane
            data_out["floor_plane"].append(DEFAULT_GROUND[:3] * DEFAULT_GROUND[3:])

        self.data_dict = data_out

    def __getitem__(self, idx):
        """Return the observation dict for track `idx` (loads data lazily)."""
        if len(self.data_dict) < 1:
            self.load_data()

        obs_data = dict()

        # 2D keypoints
        joint2d_data = self.data_dict["joints2d"][idx]
        obs_data["joints2d"] = torch.Tensor(joint2d_data)

        # single frame predictions
        obs_data["init_body_pose"] = torch.Tensor(self.data_dict["init_body_pose"][idx])
        obs_data["init_root_orient"] = torch.Tensor(
            self.data_dict["init_root_orient"][idx]
        )
        obs_data["init_trans"] = torch.Tensor(self.data_dict["init_trans"][idx])

        # floor plane
        obs_data["floor_plane"] = torch.Tensor(self.data_dict["floor_plane"][idx])

        # the frames the track is visible in
        obs_data["vis_mask"] = torch.Tensor(self.data_dict["vis_mask"][idx])

        # the frames used in this subsequence
        obs_data["seq_interval"] = torch.Tensor(list(self.seq_intervals[idx])).to(
            torch.int
        )
        # the start and end interval of available keypoints
        obs_data["track_interval"] = torch.Tensor(
            self.data_dict["track_interval"][idx]
        ).int()

        obs_data["track_id"] = int(self.track_ids[idx])
        obs_data["seq_name"] = self.seq_name
        return obs_data

    def load_camera_data(self):
        """Build the CameraData object for the selected frame range."""
        cam_dir = self.data_sources["cameras"]
        data_interval = 0, -1
        if self.split_cameras:
            # restrict cameras to the user-selected data range
            data_interval = self.data_start, self.data_end
        track_interval = self.start_idx, self.end_idx
        self.cam_data = CameraData(
            cam_dir, self.seq_len, self.img_size, data_interval, track_interval
        )

    def get_camera_data(self):
        """Return cameras as a dict; raises if load_camera_data was not called."""
        if self.cam_data is None:
            raise ValueError
        return self.cam_data.as_dict()
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class CameraData(object):
    """Per-frame camera extrinsics/intrinsics for a selected frame interval.

    Loads cameras from `<cam_dir>/cameras.npz` when available; otherwise
    falls back to a static identity camera with a default focal length.
    """

    def __init__(
        self, cam_dir, seq_len, img_size, data_interval=[0, -1], track_interval=[0, -1]
    ):
        # NOTE(review): mutable default list args — never mutated here, but
        # tuple defaults would be safer.
        # :param cam_dir: directory that may contain cameras.npz
        # :param seq_len: total number of frames in the data
        # :param img_size: (width, height) of the source images
        # :param data_interval: frame range of the loaded data (end < 0 wraps)
        # :param track_interval: range within data_interval covered by tracks
        self.img_size = img_size
        self.cam_dir = cam_dir

        # inclusive exclusive
        data_start, data_end = data_interval
        if data_end < 0:
            data_end += seq_len + 1
        data_len = data_end - data_start

        # start and end indices are with respect to the data interval
        sidx, eidx = track_interval
        if eidx < 0:
            eidx += data_len + 1
        self.sidx, self.eidx = sidx + data_start, eidx + data_start
        self.seq_len = self.eidx - self.sidx

        self.load_data()

    def load_data(self):
        """Populate cam_R, cam_t, intrins, is_static for frames [sidx, eidx)."""
        # camera info
        sidx, eidx = self.sidx, self.eidx
        img_w, img_h = self.img_size
        fpath = os.path.join(self.cam_dir, "cameras.npz")
        if os.path.isfile(fpath):
            Logger.log(f"Loading cameras from {fpath}...")
            cam_R, cam_t, intrins, width, height = load_cameras_npz(fpath)
            # rescale intrinsics from the SLAM resolution to the image resolution
            scale = img_w / width
            self.intrins = scale * intrins[sidx:eidx]
            # move first camera to origin
            # R0, t0 = invert_camera(cam_R[sidx], cam_t[sidx])
            # self.cam_R = torch.einsum("ij,...jk->...ik", R0, cam_R[sidx:eidx])
            # self.cam_t = t0 + torch.einsum("ij,...j->...i", R0, cam_t[sidx:eidx])
            # t0 = -cam_t[sidx:eidx].mean(dim=0) + torch.randn(3) * 0.1
            # NOTE(review): recentering on the first camera with small random
            # noise — makes load_data non-deterministic; presumably intended
            # as an optimization jitter, confirm with the optimizer code.
            t0 = -cam_t[sidx:sidx+1] + torch.randn(3) * 0.1
            self.cam_R = cam_R[sidx:eidx]
            self.cam_t = cam_t[sidx:eidx] - t0
            self.is_static = False
        else:
            Logger.log(f"WARNING: {fpath} does not exist, using static cameras...")
            # heuristic default focal length from the image dimensions
            default_focal = 0.5 * (img_h + img_w)
            self.intrins = torch.tensor(
                [default_focal, default_focal, img_w / 2, img_h / 2]
            )[None].repeat(self.seq_len, 1)

            self.cam_R = torch.eye(3)[None].repeat(self.seq_len, 1, 1)
            self.cam_t = torch.zeros(self.seq_len, 3)
            self.is_static = True

        Logger.log(f"Images have {img_w}x{img_h}, intrins {self.intrins[0]}")
        print("CAMERA DATA", self.cam_R.shape, self.cam_t.shape, self.intrins[0])

    def world2cam(self):
        """Return (R, t) mapping world coordinates to camera coordinates."""
        return self.cam_R, self.cam_t

    def cam2world(self):
        """Return the inverse transform (R^T, -R^T t): camera to world."""
        R = self.cam_R.transpose(-1, -2)
        t = -torch.einsum("bij,bj->bi", R, self.cam_t)
        return R, t

    def as_dict(self):
        """Package the camera parameters for consumers."""
        return {
            "cam_R": self.cam_R,  # (T, 3, 3)
            "cam_t": self.cam_t,  # (T, 3)
            "intrins": self.intrins,  # (T, 4)
            "static": self.is_static,  # bool
        }
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def get_ternary_mask(vis_mask):
    """Convert a boolean visibility mask into a ternary float mask.

    -1 = track out of scene (before first / after last sighting),
     0 = occluded within the track span, 1 = visible.
    """
    mask = torch.as_tensor(vis_mask)
    # get the track start and end idcs relative to the filtered interval
    visible = torch.where(mask)[0]
    first = min(visible)
    last = max(visible) + 1
    out = mask.float()
    out[:first] = -1
    out[last:] = -1
    return out
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def get_shot_img_files(shots_path, shot_idx, shot_pad=SHOT_PAD):
    """Select the frames that belong to shot `shot_idx`.

    :param shots_path: json file mapping image name -> shot index
    :param shot_idx: which shot to keep
    :param shot_pad: if > 0, drop this many frames next to shot boundaries
    :return (selected image names, their indices into the sorted frame list)
    :raises ValueError: when padding leaves fewer than MIN_SEQ_LEN frames
    """
    assert os.path.isfile(shots_path)
    with open(shots_path, "r") as f:
        shots_dict = json.load(f)
    img_names = sorted(shots_dict.keys())
    N = len(img_names)
    shot_mask = np.array([shots_dict[x] == shot_idx for x in img_names])

    idcs = np.where(shot_mask)[0]
    if shot_pad > 0:  # drop the frames before/after shot change
        # only pad the front if the shot doesn't start at the first frame
        if min(idcs) > 0:
            idcs = idcs[shot_pad:]
        # only pad the back if the shot doesn't end at the last frame
        if len(idcs) > 0 and max(idcs) < N - 1:
            idcs = idcs[:-shot_pad]
        if len(idcs) < MIN_SEQ_LEN:
            raise ValueError("shot is too short for optimization")

    # rebuild the mask from the (possibly padded) indices
    shot_mask = np.zeros(N, dtype=bool)
    shot_mask[idcs] = 1
    sel_paths = [img_names[i] for i in idcs]
    print(f"FOUND {len(idcs)}/{len(shots_dict)} FRAMES FOR SHOT {shot_idx}")
    return sel_paths, idcs
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def load_cameras_npz(camera_path):
    """Load per-frame world-to-camera extrinsics and intrinsics from a .npz.

    The archive must contain "height", "width", "focal" and "w2c" (N, 4, 4);
    an optional "intrins" array overrides the default [f, f, w/2, h/2].
    :return (cam_R (N, 3, 3), cam_t (N, 3), intrins (N, 4), width, height)
    """
    assert os.path.splitext(camera_path)[-1] == ".npz"

    cam_data = np.load(camera_path)
    height = int(cam_data["height"])
    width = int(cam_data["width"])
    focal = float(cam_data["focal"])

    w2c = torch.from_numpy(cam_data["w2c"])  # (N, 4, 4)
    num_cams = len(w2c)
    cam_R = w2c[:, :3, :3]  # (N, 3, 3)
    cam_t = w2c[:, :3, 3]  # (N, 3)

    if "intrins" in cam_data:
        intrins = torch.from_numpy(cam_data["intrins"].astype(np.float32))
    else:
        # fall back to a shared pinhole intrinsics vector [fx, fy, cx, cy]
        default = torch.tensor([focal, focal, width / 2, height / 2])
        intrins = default[None].repeat(num_cams, 1)

    print(f"Loaded {num_cams} cameras")
    return cam_R, cam_t, intrins, width, height
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def is_image(x):
    """True for non-hidden filenames ending in .png or .jpg."""
    if x.startswith("."):
        return False
    return x.endswith((".png", ".jpg"))
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def get_name(x):
    """Return the filename of x without directory or extension."""
    base = os.path.basename(x)
    stem, _ = os.path.splitext(base)
    return stem
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def split_name(x, suffix):
    """Return the basename of x truncated at the first occurrence of suffix."""
    base = os.path.basename(x)
    return base.split(suffix)[0]
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def get_names_in_dir(d, suffix):
    """Sorted stems of files in d whose names end with suffix."""
    matches = glob.glob(f"{d}/*{suffix}")
    return sorted(split_name(m, suffix) for m in matches)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def batch_join(parent, names, suffix=""):
    """Join each name (with optional suffix appended) onto the parent dir."""
    joined = []
    for name in names:
        joined.append(os.path.join(parent, f"{name}{suffix}"))
    return joined
|
slahmr/slahmr/data/tools.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import functools
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from body_model import OP_NUM_JOINTS
|
| 8 |
+
from scipy.interpolate import interp1d
|
| 9 |
+
from scipy.spatial.transform import Rotation, Slerp
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def read_keypoints(keypoint_fn):
    """
    Read the body keypoints of the first person in an OpenPose-format json.

    :param keypoint_fn: path to a *_keypoints.json file
    :return (OP_NUM_JOINTS, 3) array of (x, y, conf); zeros when the file is
        missing or contains no people.
    """
    # FIX: np.float was removed in NumPy 1.24 (it was an alias for the
    # builtin float, i.e. float64) — use np.float64 explicitly.
    empty_kps = np.zeros((OP_NUM_JOINTS, 3), dtype=np.float64)
    if not os.path.isfile(keypoint_fn):
        return empty_kps

    with open(keypoint_fn) as keypoint_file:
        data = json.load(keypoint_file)

    if len(data["people"]) == 0:
        print("WARNING: Found no keypoints in %s! Returning zeros!" % (keypoint_fn))
        return empty_kps

    # only the first detected person is used
    person_data = data["people"][0]
    body_keypoints = np.array(person_data["pose_keypoints_2d"], dtype=np.float64)
    body_keypoints = body_keypoints.reshape([-1, 3])
    return body_keypoints
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def read_mask_path(path):
    """
    Read the "mask_path" field of the first person in a keypoints json.

    :param path: path to a *_keypoints.json file
    :return the mask path string, or None when the file does not exist,
        contains no people, or has no "mask_path" field.
    """
    mask_path = None
    if not os.path.isfile(path):
        return mask_path

    with open(path, "r") as f:
        # FIX: the original passed the path string to json.load instead of
        # the open file handle, which always raised AttributeError.
        data = json.load(f)

    # guard against files with no detected people
    if len(data.get("people", [])) == 0:
        return mask_path

    person_data = data["people"][0]
    if "mask_path" in person_data:
        mask_path = person_data["mask_path"]

    return mask_path
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def read_smpl_preds(pred_path, num_betas=10):
    """
    Read one SMPL prediction exported from phalp outputs.

    :param pred_path: path to a *_smpl.json file
    :param num_betas: number of shape coefficients expected
    :return body_pose (23, 3), global_orient (3,), translation (3,),
        betas (num_betas,); any missing field (or a missing file) is zeros.
    """
    pose = np.zeros((23, 3))
    rot = np.zeros(3)
    trans = np.zeros(3)
    betas = np.zeros(num_betas)
    if not os.path.isfile(pred_path):
        return pose, rot, trans, betas

    with open(pred_path, "r") as f:
        payload = json.load(f)

    # each field is optional; defaults stay zero when absent
    if "body_pose" in payload:
        pose = np.array(payload["body_pose"], dtype=np.float32)
    if "global_orient" in payload:
        rot = np.array(payload["global_orient"], dtype=np.float32)
    if "cam_trans" in payload:
        trans = np.array(payload["cam_trans"], dtype=np.float32)
    if "betas" in payload:
        betas = np.array(payload["betas"], dtype=np.float32)

    return pose, rot, trans, betas
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_smpl_preds(pred_paths, interp=True, num_betas=10):
    """Load per-frame SMPL predictions, optionally filling occluded frames.

    :param pred_paths: list of N *_smpl.json paths (missing files = occluded)
    :param interp: interpolate pose/orient (Slerp) and trans/betas (linear)
        over the span between the first and last existing prediction
    :param num_betas: number of shape coefficients
    :return pose (N, 23, 3), orient (N, 3), trans (N, 3), betas (N, num_betas)

    NOTE(review): when interp=True and no prediction file exists, vis_idcs is
    empty and Slerp/min below will fail — callers appear to guarantee at
    least one visible frame; confirm.
    """
    vis_mask = np.array([os.path.isfile(x) for x in pred_paths])
    vis_idcs = np.where(vis_mask)[0]

    # load single image smpl predictions
    stack_fnc = functools.partial(np.stack, axis=0)
    # (N, 23, 3), (N, 3), (N, 3), (N, 10)
    pose, orient, trans, betas = map(
        stack_fnc, zip(*[read_smpl_preds(p, num_betas=num_betas) for p in pred_paths])
    )
    if not interp:
        return pose, orient, trans, betas

    # interpolate the occluded tracks
    orient_slerp = Slerp(vis_idcs, Rotation.from_rotvec(orient[vis_idcs]))
    trans_interp = interp1d(vis_idcs, trans[vis_idcs], axis=0)
    betas_interp = interp1d(vis_idcs, betas[vis_idcs], axis=0)

    # only fill frames within [first, last] visible — no extrapolation
    tmin, tmax = min(vis_idcs), max(vis_idcs) + 1
    times = np.arange(tmin, tmax)
    orient[times] = orient_slerp(times).as_rotvec()
    trans[times] = trans_interp(times)
    betas[times] = betas_interp(times)

    # interpolate for each joint angle
    for i in range(pose.shape[1]):
        pose_slerp = Slerp(vis_idcs, Rotation.from_rotvec(pose[vis_idcs, i]))
        pose[times, i] = pose_slerp(times).as_rotvec()

    return pose, orient, trans, betas
|
slahmr/slahmr/data/vidproc.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import subprocess
|
| 4 |
+
|
| 5 |
+
import preproc.launch_phalp as phalp
|
| 6 |
+
from preproc.launch_slam import split_frames_shots, get_command, check_intrins
|
| 7 |
+
from preproc.extract_frames import video_to_frames
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def is_nonempty(d):
    """True when d is an existing directory with at least one entry."""
    if not os.path.isdir(d):
        return False
    return len(os.listdir(d)) > 0
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def preprocess_frames(img_dir, src_path, overwrite=False, **kwargs):
    """Extract frames from the video src_path into img_dir.

    Skipped when img_dir already contains frames, unless overwrite is set.
    Remaining kwargs are forwarded to video_to_frames.
    """
    if is_nonempty(img_dir) and not overwrite:
        print(f"FOUND {len(os.listdir(img_dir))} FRAMES in {img_dir}")
        return
    print(f"EXTRACTING FRAMES FROM {src_path} TO {img_dir}")
    print(kwargs)
    status = video_to_frames(src_path, img_dir, overwrite=overwrite, **kwargs)
    assert status == 0, "FAILED FRAME EXTRACTION"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def preprocess_tracks(img_dir, track_dir, shot_dir, overwrite=False):
    """
    Run the PHALP tracker on img_dir unless track_dir already has results.

    :param img_dir
    :param track_dir, expected format: res_root/track_name/sequence
    :param shot_dir, expected format: res_root/shot_name/sequence
    :param overwrite: rerun even if outputs exist
    """
    if not overwrite and is_nonempty(track_dir):
        print(f"FOUND TRACKS IN {track_dir}")
        return

    print(f"RUNNING PHALP ON {img_dir}")
    # recover res_root / track_name / sequence from the track_dir layout
    track_root, seq = os.path.split(track_dir.rstrip("/"))
    res_root, track_name = os.path.split(track_root)
    shot_name = shot_dir.rstrip("/").split("/")[-2]
    # NOTE(review): default 0 is an int while the env var is a string —
    # presumably phalp.process_seq accepts both; confirm.
    gpu = os.environ.get("CUDA_VISIBLE_DEVICES", 0)

    phalp.process_seq(
        [gpu],
        seq,
        img_dir,
        f"{res_root}/phalp_out",
        track_name=track_name,
        shot_name=shot_name,
        overwrite=overwrite,
    )
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def preprocess_cameras(cfg, overwrite=False):
    """Run SLAM to estimate cameras unless outputs already exist.

    :param cfg: data config with .sources (images/cameras/shots), .seq,
        .shot_idx, .start_idx/.end_idx, .split_cameras
    :param overwrite: rerun even if camera outputs exist
    :raises AssertionError: when the SLAM subprocess exits non-zero
    """
    if not overwrite and is_nonempty(cfg.sources.cameras):
        print(f"FOUND CAMERAS IN {cfg.sources.cameras}")
        return

    print(f"RUNNING SLAM ON {cfg.seq}")
    img_dir = cfg.sources.images
    map_dir = cfg.sources.cameras
    # find the frame interval of the requested shot
    subseqs, shot_idcs = split_frames_shots(cfg.sources.images, cfg.sources.shots)
    shot_idx = np.where(shot_idcs == cfg.shot_idx)[0][0]
    # run on selected shot
    start, end = subseqs[shot_idx]
    if not cfg.split_cameras:
        # only run on specified segment within shot
        end = start + cfg.end_idx
        start = start + cfg.start_idx
    intrins_path = cfg.sources.get("intrins", None)
    if intrins_path is not None:
        intrins_path = check_intrins(cfg.type, cfg.root, intrins_path, cfg.seq, cfg.split)

    cmd = get_command(
        img_dir,
        map_dir,
        start=start,
        end=end,
        intrins_path=intrins_path,
        overwrite=overwrite,
    )
    print(cmd)
    gpu = os.environ.get("CUDA_VISIBLE_DEVICES", 0)
    # shell=True is needed to set CUDA_VISIBLE_DEVICES inline; cmd comes from
    # trusted config, not user input
    out = subprocess.call(f"CUDA_VISIBLE_DEVICES={gpu} {cmd}", shell=True)
    assert out == 0, "SLAM FAILED"
|
slahmr/slahmr/eval/__init__.py
ADDED
|
File without changes
|
slahmr/slahmr/eval/associate.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import json
|
| 4 |
+
import joblib
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from data.tools import read_keypoints
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def associate_phalp_track_dirs(
    phalp_dir, img_dir, track_ids, gt_kps, start=0, end=-1, debug=False
):
    """
    Associate the M track_ids with G GT tracks
    returns (M, T) array of best matching GT person index
    :param phalp_dir (str) directory with phalp track folders
    :param img_dir (str) directory with source images
    :param track ids (list) M tracks to match
    :param gt_kps (G, T, J, 3) gt keypoints for G people, T times, J joints
    :param start (optional int default 0)
    :param end (optional int default -1)
    """
    img_names = sorted([os.path.splitext(x)[0] for x in os.listdir(img_dir)])
    N = len(img_names)
    # negative end wraps around (end=-1 -> all frames)
    end = N + 1 + end if end < 0 else end
    sel_imgs = img_names[start:end]
    G, T = gt_kps.shape[:2]  # G num people, T num frames
    assert len(sel_imgs) == T, f"found {len(sel_imgs)} frames, expected {T}"

    track_ids = [f"{int(tid):03d}" for tid in track_ids]
    M = len(track_ids)
    # find the best matching GT track for each PHALP track
    # -1 marks "no match in this frame"
    match_idcs = torch.full((M, T), -1)
    for t, frame_name in enumerate(sel_imgs):
        track_kps = []  # get track keypoints
        for tid in track_ids:
            kp_path = f"{phalp_dir}/{tid}/{frame_name}_keypoints.json"
            track_kps.append(read_keypoints(kp_path))
        track_kps = np.stack(track_kps, axis=0)  # (M, 25, 3)
        for g in range(G):
            kp_gt = gt_kps[g, t].T.numpy()  # (18, 3)
            m = associate_keypoints(kp_gt, track_kps, debug=debug)
            if m == -1:
                continue
            match_idcs[m, t] = g
    return match_idcs
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def associate_phalp_track_data(
    phalp_file, track_ids, gt_kps, start=0, end=-1, debug=False
):
    """
    Match each PHALP track to its best-overlapping GT person, per frame.
    :param phalp_file (path) to a phalp result pickle file
    :param track_ids (list) of phalp track ids
    :param gt_kps (G, T, 3, 18) gt keypoints (transposed per frame below)
    :param start (optional int) first frame index
    :param end (optional int) negative values count inclusively from the end
    return (M, T) tensor with the matching GT person index for each phalp
    track, -1 where no match was found
    """
    data = joblib.load(phalp_file)
    frame_keys = sorted(data.keys())
    n_frames = len(frame_keys)  # number of frames
    if end < 0:
        # end=-1 keeps through the final frame
        end = n_frames + 1 + end
    frames = frame_keys[start:end]

    n_people, n_times = gt_kps.shape[:2]  # G num people, T num frames
    assert len(frames) == n_times, f"found {len(frames)} frames, expected {n_times}"

    n_tracks = len(track_ids)
    # row index in the output for each track id
    track_to_row = {tid: m for m, tid in enumerate(track_ids)}
    # get the best matching GT track for each PHALP track
    match_idcs = torch.full((n_tracks, n_times), -1)
    for t, frame_name in enumerate(frames):
        frame_data = data[frame_name]
        for g in range(n_people):
            kp_gt = gt_kps[g, t].T.numpy()  # (18, 3)
            # best track ID for this GT person in this frame
            tid = associate_frame_dict(frame_data, kp_gt, track_ids, debug=debug)
            if tid == -1:
                continue
            match_idcs[track_to_row[tid], t] = g
    return match_idcs
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def associate_keypoints(gt_kps, track_kps, debug=False):
    """
    Find the track whose keypoint bounding box best overlaps the GT keypoints.
    :param gt_kps (25, 3) GT keypoints as (x, y, confidence)
    :param track_kps (M, 25, 3) detected keypoints for M tracks
    :return index of the best-overlapping track, or -1 if the GT has no
        confident keypoints or no track bbox overlaps the GT bbox at all
    """
    # keep only the confident GT keypoints
    gt_kps = gt_kps[gt_kps[:, 2] > 0, :2]
    if len(gt_kps) < 1:
        return -1
    bb_min, bb_max = gt_kps.min(axis=0), gt_kps.max(axis=0)
    gt_bbox = np.concatenate([bb_min, bb_max], axis=-1)  # (4,)

    # NOTE(review): track keypoints are not confidence-filtered; missing
    # joints stored as (0, 0) would inflate the track bbox -- confirm that
    # upstream fills are benign before tightening this
    track_kps = track_kps[..., :2]  # (M, 25, 2)
    track_min, track_max = track_kps.min(axis=1), track_kps.max(axis=1)
    track_bboxes = np.concatenate([track_min, track_max], axis=-1)  # (M, 4)

    ious = np.stack([compute_iou(bb, gt_bbox)[0] for bb in track_bboxes], axis=0)
    if ious.max() <= 0:
        # no track overlaps the GT box at all; previously argmax returned
        # index 0 here, silently claiming a spurious match
        return -1
    return int(np.argmax(ious))
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def associate_frame_dict(frame_data, gt_kps, track_ids, debug=False):
    """
    For the GT keypoints, find the PHALP track in track_ids with best overlap.
    :param frame_data (dict) PHALP output for one frame, with "tid",
        "tracked_ids" and "bbox" entries
    :param gt_kps (25, 3) GT keypoints as (x, y, confidence)
    :param track_ids (list of N) PHALP track ids to search over
    :return the id in track_ids with the biggest bbox overlap with gt_kps,
        or -1 if the GT has no confident keypoints
    """
    gt_kps = gt_kps[gt_kps[:, 2] > 0, :2]
    if len(gt_kps) < 1:
        return -1
    bb_min, bb_max = gt_kps.min(axis=0), gt_kps.max(axis=0)
    gt_bbox = np.concatenate([bb_min, bb_max], axis=-1)  # (4,)

    # get the list indices of the PHALP tracks present in this frame,
    # keyed by the normalized (unpadded) string form of the id
    track_idcs = {
        str(int(tid)): i
        for i, tid in enumerate(frame_data["tid"])
        if tid in frame_data["tracked_ids"]
    }
    # select the track with the biggest overlap with the gt kps
    ious = []
    for tid in track_ids:
        # normalize the query id the same way as the dict keys were built;
        # the original compared the raw id against normalized keys, so
        # zero-padded ids (e.g. "001") could silently never match
        key = str(int(tid))
        if key not in track_idcs:
            ious.append(0)
            continue
        bb = frame_data["bbox"][track_idcs[key]]  # (min_x, min_y, w, h)
        bbox = np.concatenate([bb[:2], bb[:2] + bb[2:]], axis=-1)
        iou = compute_iou(bbox, gt_bbox)[0]
        ious.append(iou)
    ious = np.stack(ious, axis=0)
    # NOTE(review): if every IOU is 0 this still returns the first id;
    # callers only check for -1 -- consider a zero-overlap guard
    idx = np.argmax(ious)
    if debug:
        print(track_ids[idx], track_ids, ious)
    return track_ids[idx]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def compute_iou(bb1, bb2):
    """
    Compute the IOU between (broadcastable batches of) axis-aligned boxes.
    :param bb1 (..., 4) top left x, y, bottom right x, y
    :param bb2 (..., 4) top left x, y, bottom right x, y
    return (..., 1) IOU in [0, 1] (np.split keeps the last axis)
    """
    x11, y11, x12, y12 = np.split(bb1, 4, axis=-1)
    x21, y21, x22, y22 = np.split(bb2, 4, axis=-1)
    x1 = np.maximum(x11, x21)
    y1 = np.maximum(y11, y21)
    x2 = np.minimum(x12, x22)
    y2 = np.minimum(y12, y22)
    # clamp each side separately: when the boxes are disjoint on BOTH axes,
    # (x2 - x1) and (y2 - y1) are both negative and their product would be
    # spuriously positive (the original clamped only the product)
    intersect = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
    union = (x12 - x11) * (y12 - y11) + (x22 - x21) * (y22 - y21) - intersect
    # small epsilon guards against zero-area unions
    return intersect / (union + 1e-6)
|
slahmr/slahmr/eval/egobody_utils.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import itertools
|
| 3 |
+
import glob
|
| 4 |
+
import pickle
|
| 5 |
+
import json
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from tools import load_body_model, move_to, detach_all, EGOBODY_ROOT
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_sequence_body_info(seq_name):
    """Return the 'body_idx_fpv' entry for a recording from the EgoBody info csv."""
    info_df = pd.read_csv(f"{EGOBODY_ROOT}/data_info_release.csv")
    matches = info_df.loc[info_df["recording_name"] == seq_name, "body_idx_fpv"]
    return matches.values[0]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_egobody_split(split):
    """
    Return the recording names belonging to a split.
    :param split (str) column name in data_splits.csv
    :return list of recording names, or [] if the column is absent
    """
    split_file = f"{EGOBODY_ROOT}/data_splits.csv"
    split_df = pd.read_csv(split_file)
    if split in split_df.columns:
        return split_df[split].dropna().tolist()
    print(f"{split} not in {split_file}")
    return []
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_egobody_seq_paths(seq_name, start=0, end=-1):
    """
    List the image filenames of a recording, sliced to [start, end).
    :param seq_name (str) recording name
    :param start (optional int) first frame index
    :param end (optional int) negative values count inclusively from the last
        frame (end=-1 keeps through the end)
    """
    img_dir = get_egobody_img_dir(seq_name)
    # img files are named [timestamp]_frame_[index].jpg
    img_files = sorted(os.listdir(img_dir))
    # normalize a negative end the same way get_egobody_keypoints does; the
    # previous `len(img_files) if end < 0` disagreed for end < -1 and could
    # misalign the SMPL params with the keypoints in load_egobody_params
    end = len(img_files) + 1 + end if end < 0 else end
    print(f"FOUND {len(img_files)} FILES FOR SEQ {seq_name}")
    return img_files[start:end]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_egobody_seq_names(seq_name, start=0, end=-1):
    """Strip the timestamp prefix and extension from each selected image file."""
    frame_names = []
    for fname in get_egobody_seq_paths(seq_name, start=start, end=end):
        # [timestamp]_frame_[index].jpg -> frame_[index]
        stem = fname.split(".")[0]
        frame_names.append("_".join(stem.split("_")[1:]))
    return frame_names
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_egobody_img_dir(seq_name):
    """
    Locate the single PV image directory for a recording.
    Raises ValueError unless exactly one directory matches the glob.
    """
    img_dir = f"{EGOBODY_ROOT}/egocentric_color/{seq_name}/**/PV"
    matches = glob.glob(img_dir)
    if len(matches) == 1:
        return matches[0]
    raise ValueError(f"{img_dir} has {len(matches)} matches!")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_egobody_keypoints(seq_name, start=0, end=-1):
    """
    Load per-frame 2D keypoints and validity flags for a recording.
    Frames with no entry in the npz archives are filled with zero keypoints
    and valid=False.
    :param seq_name (str) recording name
    :param start (optional int) first frame index
    :param end (optional int) negative values count inclusively from the end
    :return kps (T, J, 3) keypoints, valid (T,) per-frame validity
    """
    img_dir = os.path.dirname(get_egobody_img_dir(seq_name))
    kp_file = f"{img_dir}/keypoints.npz"
    valid_file = f"{img_dir}/valid_frame.npz"

    # missing keypoints aren't included, must fill in
    kp_dict = {}
    valid_dict = {}
    kp_data = np.load(kp_file)
    valid_data = np.load(valid_file)

    # zero template with the same per-frame shape as the stored keypoints
    zeros = np.zeros_like(kp_data["keypoints"][0])
    # key both lookups by image basename so they align with the image files
    for img_path, kps in zip(kp_data["imgname"], kp_data["keypoints"]):
        img_name = os.path.basename(img_path)
        kp_dict[img_name] = kps

    for img_path, valid in zip(valid_data["imgname"], valid_data["valid"]):
        img_name = os.path.basename(img_path)
        valid_dict[img_name] = valid

    img_paths = sorted(glob.glob(f"{img_dir}/PV/*.jpg"))
    # negative end counts inclusively from the last frame (end=-1 keeps all)
    end = len(img_paths) + 1 + end if end < 0 else end
    img_names = [os.path.basename(x) for x in img_paths[start:end]]
    kps = np.stack([kp_dict.get(name, zeros) for name in img_names], axis=0)
    valid = np.stack([valid_dict.get(name, False) for name in img_names], axis=0)
    return kps, valid
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def load_egobody_smpl_params(seq_name, start=0, end=-1):
    """
    Load the interactee's per-frame GT SMPL parameters for a recording.
    :param seq_name (str) recording name
    :param start, end (optional ints) frame range; negative end counts from
        the end (see get_egobody_seq_names)
    :return dict of trans / root_orient / pose_body / betas tensors, each with
        a leading person dimension of 1, plus genders: a length-1 list with
        the interactee's gender string
        NOTE(review): the (1, T, D) shapes assume each pickled array is
        (1, D) per frame -- confirm against the dataset files
    """
    frame_names = get_egobody_seq_names(seq_name, start=start, end=end)
    # body info string is split into "<body_idx> <gender>"
    body_name = get_sequence_body_info(seq_name)
    body_idx, gender = body_name.split(" ")
    smpl_dir = (
        f"{EGOBODY_ROOT}/smpl_interactee_val/{seq_name}/body_idx_{body_idx}/results"
    )
    if not os.path.isdir(smpl_dir):
        raise ValueError(f"EXPECTED BODY DIR {smpl_dir} DOES NOT EXIST")

    print(f"LOADING {len(frame_names)} SMPL PARAMS FROM {smpl_dir}")
    smpl_dict = {"trans": [], "root_orient": [], "pose_body": [], "betas": []}
    for frame in frame_names:
        # one pickle per frame, trusted dataset files
        with open(f"{smpl_dir}/{frame}/000.pkl", "rb") as f:
            # data has global_orient, body_pose, betas, transl
            data = pickle.load(f)
            smpl_dict["trans"].append(torch.from_numpy(data["transl"]))
            smpl_dict["pose_body"].append(torch.from_numpy(data["body_pose"]))
            smpl_dict["root_orient"].append(torch.from_numpy(data["global_orient"]))
            smpl_dict["betas"].append(torch.from_numpy(data["betas"]))
    # concatenate frames along dim 0, then add a leading person dimension
    smpl_dict = {k: torch.cat(v, dim=0)[None] for k, v in smpl_dict.items()}
    smpl_dict["genders"] = [gender]
    return smpl_dict
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def load_egobody_intrinsics(seq_name, start=0, end=-1, ret_size_tuple=True):
    """
    Load per-frame camera intrinsics for a recording.
    :param ret_size_tuple if True, return the image size of the first frame
        as a python list; otherwise return per-frame sizes as a tensor
    :return intrins (T, 4) float32 tensor, and the image size
    """
    path = f"{EGOBODY_ROOT}/slahmr/cameras_gt/{seq_name}/intrinsics.txt"
    assert os.path.isfile(path)
    raw = np.loadtxt(path)  # (T, 6); last two columns are the image size
    if end < 0:
        end = len(raw)
    raw = raw[start:end]
    # the 4 intrinsic values are identical in both return modes
    intrins = torch.from_numpy(raw[:, :4].astype(np.float32))
    if ret_size_tuple:
        img_size = raw[0, 4:].astype(int).tolist()  # (2)
        return intrins, img_size
    img_size = torch.from_numpy(raw[:, 4:].astype(int))
    return intrins, img_size
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def load_egobody_gt_extrinsics(seq_name, start=0, end=-1, ret_4d=True):
    """
    Load the ground-truth cam2world poses of a recording.
    :param ret_4d if True return (T, 4, 4) matrices, else (R (T, 3, 3), t (T, 3))
    """
    path = f"{EGOBODY_ROOT}/slahmr/cameras_gt/{seq_name}/cam2world.txt"
    assert os.path.isfile(path)
    mats = np.loadtxt(path).astype(np.float32)  # (T, 16) flattened 4x4 per row
    if end < 0:
        end = len(mats)
    cam2world = torch.from_numpy(mats[start:end].reshape(-1, 4, 4))
    if not ret_4d:
        return cam2world[:, :3, :3], cam2world[:, :3, 3]
    return cam2world
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def load_egobody_extrinsics(seq_name, use_intrins=True, start=0, end=-1):
    """
    Load estimated world-to-camera poses and return the cam2world rotation
    and translation.
    :param use_intrins select the camera estimates computed with intrinsics
    :return R (T, 3, 3), t (T, 3) of the inverted (cam2world) transforms
    """
    camera_name = "cameras_intrins" if use_intrins else "cameras_default"
    path = f"{EGOBODY_ROOT}/slahmr/{camera_name}/{seq_name}/cameras.npz"
    assert os.path.isfile(path)
    w2c = torch.from_numpy(np.load(path)["w2c"].astype(np.float32))  # (N, 4, 4)
    if end < 0:
        end = len(w2c)
    # invert world2cam to get cam2world
    c2w = torch.linalg.inv(w2c[start:end])
    return c2w[:, :3, :3], c2w[:, :3, 3]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def load_egobody_meshes(seq_name, device, start=0, end=-1):
    """
    Run the GT SMPL parameters through the body model to get posed meshes.
    :param seq_name (str) recording name
    :param device torch device to run the body model on
    :param start, end (optional ints) frame range
    :return dict (detached, moved to cpu) with joints, vertices and faces
        of the body model output
    """
    params = load_egobody_smpl_params(seq_name, start=start, end=end)
    _, T = params["trans"].shape[:2]

    with torch.no_grad():
        # a single annotated person per egobody FPV recording (see
        # load_egobody_smpl_params), hence index 0 everywhere
        gender = params["genders"][0]
        body_model = load_body_model(T, "smpl", gender, device)
        smpl_res = body_model(
            trans=params["trans"][0].to(device),
            root_orient=params["root_orient"][0].to(device),
            betas=params["betas"][0].to(device),
            pose_body=params["pose_body"][0].to(device),
        )

    res = {"joints": smpl_res.Jtr, "vertices": smpl_res.v, "faces": smpl_res.f}
    return move_to(detach_all(res), "cpu")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def load_egobody_kinect2holo(seq_name, ret_4d=True):
    """
    Load the rigid transform from the kinect12 frame to the holo frame.
    Bodies are recorded in the kinect12 frame; the calibration file stores
    the inverse (holo_to_kinect12), so it is inverted here.
    :param ret_4d if True return the (4, 4) matrix, else (R (3, 3), t (3,))
    """
    path = f"{EGOBODY_ROOT}/calibrations/{seq_name}/cal_trans/holo_to_kinect12.json"
    with open(path, "r") as f:
        holo_to_kinect = np.array(json.load(f)["trans"])
    kinect2holo = torch.from_numpy(
        np.linalg.inv(holo_to_kinect).astype(np.float32)
    )
    if not ret_4d:
        return kinect2holo[:3, :3], kinect2holo[:3, 3]
    return kinect2holo
|
slahmr/slahmr/eval/run_eval.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import joblib
|
| 4 |
+
import json
|
| 5 |
+
import pickle
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
import egobody_utils as eb_util
|
| 12 |
+
from tools import (
|
| 13 |
+
load_body_model,
|
| 14 |
+
load_results_all,
|
| 15 |
+
local_align_joints,
|
| 16 |
+
global_align_joints,
|
| 17 |
+
first_align_joints,
|
| 18 |
+
compute_accel_norm,
|
| 19 |
+
run_smpl,
|
| 20 |
+
JointRegressor,
|
| 21 |
+
EGOBODY_ROOT,
|
| 22 |
+
TDPW_ROOT,
|
| 23 |
+
)
|
| 24 |
+
from associate import associate_phalp_track_dirs
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def stack_torch(x_list, dim=0):
    """Convert a list of numpy arrays to float32 tensors and stack along `dim`."""
    tensors = [torch.from_numpy(arr.astype(np.float32)) for arr in x_list]
    return torch.stack(tensors, dim=dim)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def load_3dpw_params(seq_name, start=0, end=-1):
    """
    Load GT SMPL parameters and 2D keypoints for a 3DPW test sequence.
    :param seq_name (str) sequence name (without extension)
    :param start (optional int) first frame index
    :param end (optional int) negative values count inclusively from the end
    :return dict with root_orient (M, T, 3), pose_body (M, T, 69),
        trans (M, T, 3), betas (M, T, 10), keypts2d (M, T, 3, 18),
        valid (M, T) bool, genders: list of M "male"/"female" strings
    """
    seq_file = f"{TDPW_ROOT}/sequenceFiles/test/{seq_name}.pkl"
    with open(seq_file, "rb") as f:
        # 3DPW pickles were written with python 2; latin1 keeps the
        # embedded numpy arrays decodable
        data = pickle.load(f, encoding="latin1")

    M = len(data["poses"])  # number of people
    T = len(data["poses"][0])  # number of frames
    end = T + 1 + end if end < 0 else end
    T = end - start
    trans = stack_torch([x[start:end] for x in data["trans"]])  # (M, T, 3)
    poses = stack_torch([x[start:end] for x in data["poses"]])  # (M, T, 72)
    # betas are per-person constants; broadcast them over the T frames
    betas = stack_torch([x[None, :10] for x in data["betas"]]).expand(
        M, T, 10
    )  # (M, T, 10)
    keypts2d = stack_torch([x[start:end] for x in data["poses2d"]])  # (M, T, 3, 18)
    valid_cam = stack_torch(
        [x[start:end] for x in data["campose_valid"]]
    ).bool()  # (M, T)
    # a frame counts as valid only when the camera pose is valid AND at
    # least one 2D keypoint value is positive
    valid_kp = (keypts2d.reshape(M, T, -1) > 0).any(dim=-1).bool()  # (M, T)
    valid = valid_cam & valid_kp
    genders = ["male" if x == "m" else "female" for x in data["genders"]]  # (M)
    return {
        "root_orient": poses[..., :3],  # global orientation
        "pose_body": poses[..., 3:],  # remaining 69 pose values
        "trans": trans,
        "betas": betas,
        "keypts2d": keypts2d,
        "valid": valid,
        "genders": genders,
    }
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_egobody_params(seq_name, start=0, end=-1):
    """
    Load GT SMPL parameters plus 2D keypoints and validity for one recording.
    returns dict of
        - trans (1, T, 3)
        - root_orient (1, T, 3)
        - pose_body (1, T, 63)
        - betas (1, T, 10)
        - genders (length-1 list)
        - keypts2d (1, T, J, 3)
        - valid (1, T)
    """
    params = eb_util.load_egobody_smpl_params(seq_name, start=start, end=end)
    kps, valid = eb_util.get_egobody_keypoints(seq_name, start=start, end=end)
    # add the same leading person dimension as the SMPL params
    params["keypts2d"] = torch.from_numpy(kps.astype(np.float32)).unsqueeze(0)
    params["valid"] = torch.from_numpy(valid.astype(bool)).unsqueeze(0)
    return params
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def eval_result_dir(
    dset_type, res_dir, out_path, joint_reg, dev_id=0, overwrite=False, debug=False
):
    """
    Evaluate one optimized result directory against GT and write a metrics csv.
    :param dset_type (str) "egobody" or "3dpw"
    :param res_dir (str) result directory containing track_info.json and
        one subdirectory per optimization phase
    :param out_path (str) csv file to write per-phase metrics to
    :param joint_reg (callable) regresses vertices to joints
    :param dev_id (int) cuda device index
    :param overwrite (bool) recompute even if out_path already exists
    :param debug (bool) print association diagnostics
    """
    if os.path.isfile(out_path) and not overwrite:
        print(f"{out_path} already exists, skipping.")
        return

    # get the output metadata
    track_file = f"{res_dir}/track_info.json"
    if not os.path.isfile(track_file):
        print(f"{track_file} does not exist, skipping")
        return

    with open(track_file, "r") as f:
        track_dict = json.load(f)
    start, end = track_dict["meta"]["data_interval"]
    seq_name = os.path.basename(res_dir).split("-")[0]
    print("EVALUATING", res_dir, seq_name, start, end)

    # get the associations from PHALP tracks to GT tracks
    track_info = track_dict["tracks"]
    track_ids = sorted(track_info, key=lambda k: track_info[k]["index"])
    print("TRACK IDS", track_ids)

    if dset_type == "egobody":
        # load the GT params
        gt_params = load_egobody_params(seq_name, start, end)
        phalp_dir = f"{EGOBODY_ROOT}/slahmr/track_preds/{seq_name}"
        img_dir = eb_util.get_egobody_img_dir(seq_name)
    elif dset_type == "3dpw":
        gt_params = load_3dpw_params(seq_name, start, end)
        phalp_dir = f"{TDPW_ROOT}/slahmr/track_gt/{seq_name}"
        img_dir = f"{TDPW_ROOT}/imageFiles/{seq_name}"
    else:
        raise NotImplementedError

    # (M, T) GT track index for each frame and each PHALP track
    match_idcs = associate_phalp_track_dirs(
        phalp_dir,
        img_dir,
        track_ids,
        gt_params["keypts2d"],
        start=start,
        end=end,
        debug=debug,
    )
    # M number of PHALP tracks
    M = len(track_ids)

    # get the GT joints by running each person through the body model
    G, T = gt_params["pose_body"].shape[:2]
    device = torch.device(f"cuda:{dev_id}")
    gt_joints = []
    for g in range(G):
        body_model = load_body_model(T, "smpl", gt_params["genders"][g], device)
        gt_smpl = run_smpl(
            body_model,
            betas=gt_params["betas"][g].to(device),
            trans=gt_params["trans"][g].to(device),
            root_orient=gt_params["root_orient"][g].to(device),
            pose_body=gt_params["pose_body"][g].to(device),
        )
        gt_joints.append(joint_reg(gt_smpl["vertices"]))  # (T, 15, 3)
    gt_joints = torch.stack(gt_joints, dim=0)
    J, D = gt_joints.shape[-2:]

    # select the correct GT person for each track
    gt_valid = gt_params["valid"]  # (G, T)
    idcs = match_idcs.clone().reshape(M, T, 1, 1).expand(-1, -1, J, D)
    idcs[idcs == -1] = 0  # gather dummy for invalid matches
    gt_match_joints = torch.gather(gt_joints, 0, idcs)
    gt_match_valid = torch.gather(gt_valid, 0, idcs[:, :, 0, 0])
    valid = gt_match_valid & (match_idcs != -1)

    # use the vis_mask to get the correct data subsequence
    vis_mask = torch.tensor(
        [track_info[tid]["vis_mask"] for tid in track_ids]
    )  # (M, T)
    vis_tracks = torch.where(vis_mask.any(dim=1))[0]  # (B,)
    vis_idcs = torch.where(vis_mask.any(dim=0))[0]
    sidx, eidx = vis_idcs.min(), vis_idcs.max() + 1
    L = eidx - sidx

    valid_seq = valid[vis_tracks, sidx:eidx]  # (B, L)
    gt_seq_joints = gt_match_joints[vis_tracks, sidx:eidx]  # (B, L, *)
    gt_seq_joints = gt_seq_joints[valid_seq]

    if debug:
        print(f"vis start {sidx}, end {eidx}, L {L}")
        print("valid track matches", (match_idcs != -1).sum())
        print("filtered gt joints", gt_seq_joints.shape)

    # get the outputs of each phase
    PHASES = ["root_fit", "smooth_fit", "motion_chunks"]
    metric_names = ["ga_jmse", "fa_jmse", "pampjpe", "acc_norm"]
    # BUG FIX: was `for _ in PHASE` (NameError); one slot per phase
    phase_metrics = {name: [-1 for _ in PHASES] for name in metric_names}
    cur_metrics = {name: np.nan for name in metric_names}
    for i, phase in enumerate(PHASES):
        res_dict = load_results_all(os.path.join(res_dir, phase), device)
        if res_dict is None:
            print(f"PHASE {phase} did not optimize")
            # carry forward the previous phase's metrics for this phase
            for name in metric_names:
                phase_metrics[name][i] = float(cur_metrics[name])
            print(phase, phase_metrics)
            continue

        # (M, L, -1, 3) verts, (M, L) mask
        res_verts = res_dict["vertices"][valid_seq]
        res_joints = joint_reg(res_verts)  # (*, 15, 3)

        for name in metric_names:
            # BUG FIX: the original used independent `if`s with a trailing
            # `else: raise NotImplementedError` bound only to the last `if`,
            # so "ga_jmse" and "pampjpe" always raised
            if name == "acc_norm":
                target = compute_accel_norm(gt_seq_joints)  # (T-2, J)
                pred = compute_accel_norm(res_joints)
            elif name == "pampjpe":
                target = gt_seq_joints
                pred = local_align_joints(gt_seq_joints, res_joints)
            elif name == "ga_jmse":
                target = gt_seq_joints
                pred = global_align_joints(gt_seq_joints, res_joints)
            elif name == "fa_jmse":
                target = gt_seq_joints
                pred = first_align_joints(gt_seq_joints, res_joints)
            else:
                raise NotImplementedError
            cur_metrics[name] = torch.linalg.norm(target - pred, dim=-1).mean()
            phase_metrics[name][i] = float(cur_metrics[name])
            print(phase, name, cur_metrics[name])

    df_dict = {"phases": PHASES}
    df_dict.update(phase_metrics)
    df = pd.DataFrame.from_dict(df_dict)
    df.to_csv(out_path, index=False)
    print(f"saved metrics to {out_path}")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def parse_job_file(args):
    """
    Parse a job spec file into result subdirectory names.
    Each line is "<seq> data.start_idx=<s> data.end_idx=<e> [data.track_ids=<t>]";
    the subsequence name is "<seq>-<track>-<s>-<e>" where <track> defaults to
    "longest-2" for 3dpw and "all" otherwise.
    """
    subseq_names = []
    with open(args.job_file, "r") as f:
        for line in f.readlines():
            fields = line.strip().split()
            seq_name, start_field, end_field = fields[:3]
            start = start_field.split("=")[-1]
            end = end_field.split("=")[-1]
            if args.dset_type == "3dpw":
                track_name = "longest-2"
            else:
                track_name = "all"
            # an explicit track spec on the line overrides the default
            if len(fields) > 3:
                track_name = fields[3].split("=")[-1]
            subseq_names.append(f"{seq_name}-{track_name}-{start}-{end}")
    return subseq_names
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def main(args):
    """Evaluate every subsequence listed in the job file and aggregate metrics."""
    joint_reg = JointRegressor()
    out_root = args.out_root if args.out_root is not None else args.res_root
    os.makedirs(out_root, exist_ok=True)

    subseq_names = parse_job_file(args)
    for subseq in subseq_names:
        res_dir = os.path.join(args.res_root, subseq)
        out_path = os.path.join(out_root, f"{subseq}.txt")
        eval_result_dir(
            args.dset_type,
            res_dir,
            out_path,
            joint_reg,
            overwrite=args.overwrite,
            debug=args.debug,
        )

    # aggregate per-subsequence metric files ([!_] skips the _final output)
    metric_paths = glob.glob(f"{out_root}/[!_]*.txt")
    if not metric_paths:
        # avoid pd.concat raising on an empty list
        print(f"no metric files found in {out_root}, nothing to aggregate")
        return
    dfs = [pd.read_csv(path) for path in metric_paths]

    # BUG FIX: the per-file csvs are written with a "phases" column (see
    # eval_result_dir); grouping by "phase" raised a KeyError
    merged = pd.concat(dfs).groupby("phases").mean()
    merged.to_csv(f"{out_root}/_final_metrics.txt")
    print(merged)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dset_type",
        required=True,
        choices=["egobody", "3dpw"],
        help="dataset to evaluate on, choices: (3dpw, egobody)",
    )
    parser.add_argument(
        "-i", "--res_root", required=True, help="root directory of outputs to evaluate"
    )
    parser.add_argument(
        "-f",
        "--job_file",
        required=True,
        help="job file specifying the examples to run and evaluate",
    )
    parser.add_argument(
        "-o",
        "--out_root",
        default=None,
        help="directory to save computed metrics, default is res_root",
    )
    parser.add_argument("-y", "--overwrite", action="store_true")
    # BUG FIX: "-d" was registered twice (--dset_type and --debug), which
    # raises argparse.ArgumentError at startup; --debug keeps no short flag
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)
|
slahmr/slahmr/eval/split_3dpw.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import itertools
|
| 4 |
+
import joblib
|
| 5 |
+
|
| 6 |
+
from eval_3dpw import load_3dpw_params
|
| 7 |
+
from associate import associate_frame
|
| 8 |
+
from tools import TDPW_ROOT
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
Script to find the associations of ground truth 3DPW tracks with the detected PHALP tracks
|
| 13 |
+
Will write a job specification file to ../job_specs with which track IDs to run optimization on
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
IMG_ROOT = f"{TDPW_ROOT}/imageFiles"
|
| 17 |
+
SRC_DIR = f"{TDPW_ROOT}/sequenceFiles"
|
| 18 |
+
PHALP_DIR = f"{TDPW_ROOT}/slahmr/phalp_out/results"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_split_sequences(split):
    """Return the 3DPW sequence names (no extension) in a split directory."""
    assert split in ["train", "val", "test"]
    seq_files = sorted(os.listdir(f"{SRC_DIR}/{split}"))
    return [os.path.splitext(name)[0] for name in seq_files]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def select_phalp_tracks(seq_name, split, start, end, debug=False):
    """
    Select the best phalp track for each GT person for each frame.
    Returns all phalp track ids (as ints) that matched some GT person at
    least once over the subsequence [start, end).
    :param seq_name (str) 3DPW sequence name
    :param split (str) "train" / "val" / "test"
    :param start, end (ints) frame range; NOTE(review): `end` is used as a
        raw slice index here (no negative normalization, unlike run_eval) --
        confirm callers always pass absolute ends
    """
    phalp_file = f"{PHALP_DIR}/{seq_name}.pkl"
    track_data = joblib.load(phalp_file)
    img_names = sorted(track_data.keys())
    sel_imgs = img_names[start:end]

    # NOTE(review): load_3dpw_params here comes from the eval_3dpw module and
    # is called with a full file path, unlike run_eval.load_3dpw_params which
    # takes a bare sequence name -- verify the intended module resolves
    gt_params = load_3dpw_params(f"{SRC_DIR}/{split}/{seq_name}.pkl", start, end)
    gt_kps = gt_params["keypts2d"]
    G, T = gt_kps.shape[:2]  # G num people, T num frames
    assert len(sel_imgs) == T, f"found {len(sel_imgs)} frames, expected {T}"

    # every track id seen anywhere in the subsequence is a candidate
    track_ids = set()
    for frame in sel_imgs:
        frame_data = track_data[frame]
        for tid in frame_data["tracked_ids"]:
            track_ids.add(str(tid))
    track_ids = list(track_ids)
    # NOTE(review): M and track_idcs are computed but never used below
    M = len(track_ids)
    track_idcs = {tid: m for m, tid in enumerate(track_ids)}

    # get the best matching PHALP track for each GT person
    sel_tracks = set()
    for t, frame_name in enumerate(sel_imgs):
        frame_data = track_data[frame_name]
        for g in range(G):
            kp_gt = gt_kps[g, t].T.numpy()  # (18, 3)
            # get the best track ID for the GT person
            tid = associate_frame(frame_data, kp_gt, track_ids, debug=debug)
            if tid == -1:
                continue
            sel_tracks.add(int(tid))
    return list(sel_tracks)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seq_len", type=int, default=100)
    parser.add_argument(
        "--split", default="test", choices=["train", "val", "test", "all"]
    )
    parser.add_argument("--prefix", default="3dpw")
    args = parser.parse_args()

    seqs = load_split_sequences(args.split)

    job_arg_strs = []
    for seq in seqs:
        num_imgs = len(glob.glob(f"{IMG_ROOT}/{seq}/*.jpg"))
        if num_imgs == 0:
            continue
        # chunk the sequence into jobs of ~seq_len frames
        splits = list(range(0, num_imgs, args.seq_len))
        if len(splits) == 1 or num_imgs % args.seq_len == 0:
            # BUG FIX: sequences shorter than seq_len, or an exact multiple
            # of it, must close the last chunk explicitly; overwriting
            # splits[-1] here dropped a chunk or skipped the sequence
            splits.append(num_imgs)
        else:
            splits[-1] = num_imgs  # just add the remainder to the last job
        for start, end in zip(splits[:-1], splits[1:]):
            sel_tracks = select_phalp_tracks(seq, args.split, start, end)
            if len(sel_tracks) < 1:
                continue
            track_str = "-".join([f"{tid:03d}" for tid in sel_tracks])
            arg_str = (
                f"{seq} data.start_idx={start} data.end_idx={end} "
                f"data.track_ids={track_str}"
            )
            print(arg_str)
            job_arg_strs.append(arg_str)

    with open(
        f"../job_specs/{args.prefix}_{args.split}_len_{args.seq_len}.txt", "w"
    ) as f:
        f.write("\n".join(job_arg_strs))
|
slahmr/slahmr/eval/split_egobody.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import joblib
|
| 4 |
+
import itertools
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
from associate import associate_frame
|
| 10 |
+
from tools import EGOBODY_ROOT
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
Script to find the associations of ground truth Egobody tracks with the detected PHALP tracks
|
| 15 |
+
Will write a job specification file to ../job_specs with which track IDs to run optimization on
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
IMG_ROOT = f"{EGOBODY_ROOT}/egocentric_color"
|
| 19 |
+
PHALP_DIR = f"{EGOBODY_ROOT}/slahmr/phalp_out/results"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_split_sequences(split):
    """
    Return the Egobody sequence names belonging to the requested split.

    :param split (str) column name in data_splits.csv, e.g. "train"/"val"/"test"
    returns list of sequence name strings ([] if the split is unknown)
    """
    # BUG FIX: the f-prefix was missing, so the literal string
    # "{EGOBODY_ROOT}/data_splits.csv" was passed to pd.read_csv
    split_file = f"{EGOBODY_ROOT}/data_splits.csv"
    df = pd.read_csv(split_file)
    if split not in df.columns:
        print(f"{split} not in {split_file}")
        return []
    return df[split].dropna().tolist()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_egobody_keypoints(img_dir, start, end):
    """
    Load GT keypoints and per-frame validity flags for frames [start:end).

    :param img_dir (str) sequence dir containing keypoints.npz, valid_frame.npz, PV/
    :param start, end (int) frame index range into the sorted PV frame list
    returns kps (T, ...) keypoint array, valid (T,) validity array
    """
    kp_data = np.load(f"{img_dir}/keypoints.npz")
    valid_data = np.load(f"{img_dir}/valid_frame.npz")

    frame_paths = sorted(glob.glob(f"{img_dir}/PV/*.jpg"))[start:end]
    frame_names = [os.path.basename(p) for p in frame_paths]

    # index the npz records by image basename
    kp_dict = {
        os.path.basename(p): k
        for p, k in zip(kp_data["imgname"], kp_data["keypoints"])
    }
    valid_dict = {
        os.path.basename(p): v
        for p, v in zip(valid_data["imgname"], valid_data["valid"])
    }

    # frames without annotations get all-zero keypoints and are marked invalid
    missing = np.zeros_like(kp_data["keypoints"][0])
    kps = np.stack([kp_dict.get(name, missing) for name in frame_names], axis=0)
    valid = np.stack([valid_dict.get(name, False) for name in frame_names], axis=0)
    return kps, valid
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def select_phalp_tracks(seq_name, img_dir, start, end, debug=False):
    """
    Get the best phalp track for each GT person for each frame.
    Returns all phalp track ids (ints) that match GT over the sequence.

    :param seq_name (str) sequence name (stem of the PHALP .pkl file)
    :param img_dir (str) directory with GT keypoints for the sequence
    :param start, end (int) frame index range
    :param debug (bool) forwarded to associate_frame
    """
    phalp_file = f"{PHALP_DIR}/{seq_name}.pkl"
    track_data = joblib.load(phalp_file)
    img_names = sorted(track_data.keys())
    sel_imgs = img_names[start:end]

    # per-frame validity is not used here; only the keypoints themselves
    kps_all, _ = get_egobody_keypoints(img_dir, start, end)
    T = len(kps_all)
    assert len(sel_imgs) == T, f"found {len(sel_imgs)} frames, expected {T}"

    # collect all PHALP track ids observed in the selected frames
    # (removed unused locals M and track_idcs from the original)
    track_ids = set()
    for frame in sel_imgs:
        frame_data = track_data[frame]
        for tid in frame_data["tracked_ids"]:
            track_ids.add(str(tid))
    track_ids = list(track_ids)

    # get the best matching PHALP track for each GT person
    sel_tracks = set()
    for t, frame_name in enumerate(sel_imgs):
        frame_data = track_data[frame_name]
        # get the best track ID for the GT person
        tid = associate_frame(frame_data, kps_all[t], track_ids, debug=debug)
        if tid == -1:
            continue
        sel_tracks.add(int(tid))
    return list(sel_tracks)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seq_len", type=int, default=100)
    parser.add_argument("--split", default="val", choices=["train", "val", "test"])
    parser.add_argument("--prefix", default="ego")
    args = parser.parse_args()

    seqs = load_split_sequences(args.split)

    job_arg_strs = []
    for seq in seqs:
        sub_dirs = glob.glob(f"{IMG_ROOT}/{seq}/**/")
        if len(sub_dirs) < 1:
            # no recording directory on disk; previously an unguarded [0]
            continue
        img_dir = sub_dirs[0]
        num_imgs = len(glob.glob(f"{img_dir}/PV/*.jpg"))
        if num_imgs < 1:
            # no frames; `splits[-1]` below would IndexError on an empty list
            continue
        splits = list(range(0, num_imgs, args.seq_len))
        if len(splits) > 1:
            splits[-1] = num_imgs  # just add the remainder to the last job
        else:
            # sequence shorter than seq_len: previously produced no jobs at all
            splits.append(num_imgs)
        for start, end in zip(splits[:-1], splits[1:]):
            sel_tracks = select_phalp_tracks(seq, img_dir, start, end)
            if len(sel_tracks) < 1:
                continue
            track_str = "-".join(f"{tid:03d}" for tid in sel_tracks)
            arg_str = (
                f"{seq} data.start_idx={start} data.end_idx={end} "
                f"data.track_ids={track_str}"
            )
            print(arg_str)
            job_arg_strs.append(arg_str)

    # one line per optimization job
    with open(
        f"../job_specs/{args.prefix}_{args.split}_len_{args.seq_len}_tracks.txt", "w"
    ) as f:
        f.write("\n".join(job_arg_strs))
|
slahmr/slahmr/eval/tools.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import glob
|
| 4 |
+
import json
|
| 5 |
+
import joblib
|
| 6 |
+
import numpy as np
|
| 7 |
+
import smplx
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from util.loaders import load_smpl_body_model
|
| 11 |
+
from util.tensor import move_to, detach_all, to_torch
|
| 12 |
+
from optim.output import load_result, get_results_paths
|
| 13 |
+
from geometry.pcl import align_pcl
|
| 14 |
+
from geometry.rotation import batch_rodrigues
|
| 15 |
+
|
| 16 |
+
# repository root, three levels above this file
BASE_DIR = os.path.abspath(f"{__file__}/../../../")
# H3.6M joint regressor weights (17 joints x 6890 SMPL vertices)
JOINT_REG_PATH = f"{BASE_DIR}/_DATA/body_models/J_regressor_h36m.npy"


# XXX: Sorry, need to change this yourself
EGOBODY_ROOT = "/path/to/egobody"
TDPW_ROOT = "/path/to/3DPW"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class JointRegressor(object):
    """
    Regress the 14 standard evaluation joints (plus root) from SMPL vertices
    using the H3.6M joint regressor.
    """

    def __init__(self):
        # full H3.6M regressor, (17, 6890)
        full_reg = torch.from_numpy(np.load(JOINT_REG_PATH).astype(np.float32))
        # (14,) adding the root, but will omit
        eval_joints = torch.tensor([6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10])
        self.regressor = full_reg[eval_joints]  # (14, 6890)

    def to(self, device):
        self.regressor = self.regressor.to(device)

    def __call__(self, verts):
        """
        NOTE: RETURNS ROOT AS WELL
        :param verts (*, V, 3)
        returns (*, J, 3) 14 standard evaluation joints
        """
        return torch.einsum("nv,...vd->...nd", self.regressor, verts)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def compute_accel_norm(joints):
    """
    Magnitude of the finite-difference acceleration of a joint trajectory.

    :param joints (T, J, 3)
    returns (T-2, J) norm of the second finite difference
    """
    velocity = joints[1:] - joints[:-1]  # (T-1, J, 3)
    accel = velocity[1:] - velocity[:-1]  # (T-2, J, 3)
    return torch.linalg.norm(accel, dim=-1)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def global_align_joints(gt_joints, pred_joints):
    """
    Align pred_joints to gt_joints with a single similarity transform
    (one scale/rotation/translation fit over all frames jointly).

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    returns (T, J, 3) globally aligned predictions
    """
    s_glob, R_glob, t_glob = align_pcl(
        gt_joints.reshape(-1, 3), pred_joints.reshape(-1, 3)
    )
    # apply the single fitted transform to every frame
    pred_glob = (
        s_glob * torch.einsum("ij,tnj->tni", R_glob, pred_joints) + t_glob[None, None]
    )
    return pred_glob
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def first_align_joints(gt_joints, pred_joints):
    """
    align the first two frames
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    returns (T, J, 3) predictions aligned by the transform fit on frames 0-1
    """
    # (1, 1), (1, 3, 3), (1, 3)
    s_first, R_first, t_first = align_pcl(
        gt_joints[:2].reshape(1, -1, 3), pred_joints[:2].reshape(1, -1, 3)
    )
    # apply that single transform to the whole trajectory
    pred_first = (
        s_first * torch.einsum("tij,tnj->tni", R_first, pred_joints) + t_first[:, None]
    )
    return pred_first
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def local_align_joints(gt_joints, pred_joints):
    """
    Per-frame similarity alignment (separate fit for each frame,
    i.e. PA-MPJPE-style alignment).

    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    returns (T, J, 3) per-frame aligned predictions
    """
    s_loc, R_loc, t_loc = align_pcl(gt_joints, pred_joints)
    pred_loc = (
        s_loc[:, None] * torch.einsum("tij,tnj->tni", R_loc, pred_joints)
        + t_loc[:, None]
    )
    return pred_loc
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def load_body_model(batch_size, model_type, gender, device):
    """
    Load a SMPL or SMPL-H body model with the standard settings for each type.

    :param batch_size (int) model batch size
    :param model_type (str) "smpl" or "smplh"
    :param gender (str) e.g. "neutral", "male", "female"
    :param device device to load the model on
    returns the body model (the loader's fit_gender return is discarded)
    """
    assert model_type in ["smpl", "smplh"]
    if model_type == "smpl":
        num_betas = 10
        ext = "pkl"
        use_vtx_selector = False
    else:
        # SMPL-H uses 16 betas, npz weights, and the extra vertex selector
        num_betas = 16
        ext = "npz"
        use_vtx_selector = True

    smpl_path = f"{BASE_DIR}/body_models/{model_type}/{gender}/model.{ext}"
    body_model, fit_gender = load_smpl_body_model(
        smpl_path,
        batch_size,
        num_betas,
        model_type=model_type,
        use_vtx_selector=use_vtx_selector,
        device=device,
    )
    return body_model
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def run_smpl(body_model, *args, **kwargs):
    """
    Run the body model without gradients and return detached CPU tensors.

    :param body_model SMPL-style model; args/kwargs are forwarded to it
    returns dict with "joints" (Jtr), "vertices" (v), "faces" (f)
    """
    with torch.no_grad():
        results = body_model(*args, **kwargs)
    return {
        "joints": results.Jtr.detach().cpu(),
        "vertices": results.v.detach().cpu(),
        "faces": results.f.detach().cpu(),
    }
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def run_smpl_batch(body_model, device, **kwargs):
    """
    Run the body model on kwargs with arbitrary leading dims by flattening
    them to the model batch size, then unflattening the outputs.

    :param body_model wrapper whose .bm.batch_size equals the product of each
        kwarg's leading dims
    :param device device the model runs on
    :param kwargs tensors of shape (*dims, D)
    returns dict of outputs reshaped back to (*dims, ...) where the flat
        batch dim matches B; other outputs (e.g. faces) pass through unchanged
    """
    model_kwargs = {}
    B = body_model.bm.batch_size
    kwarg_shape = (B,)  # fallback used only if kwargs is empty
    for k, v in kwargs.items():
        # NOTE(review): kwarg_shape ends up as the LAST kwarg's leading dims;
        # assumes all kwargs share the same leading dims -- confirm at call sites
        kwarg_shape = v.shape[:-1]
        model_kwargs[k] = v.reshape(B, v.shape[-1]).to(device)
    res_flat = run_smpl(body_model, **model_kwargs)
    res = {}
    for k, v in res_flat.items():
        sh = v.shape
        if sh[0] == B:
            # unflatten the batch dim back to the original leading dims
            v = v.reshape(*kwarg_shape, *sh[1:])
        res[k] = v
    return res
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def cat_dicts(dict_list, dim=0):
    """
    Stack matching tensors from a list of dicts along a new dimension.

    :param dict_list list of dicts with identical keys and stackable tensors
    :param dim (int) dimension to stack along
    returns dict mapping each key to the stacked tensor
    """
    keys = set(dict_list[0].keys())
    assert all(set(d.keys()) == keys for d in dict_list)
    out = {}
    for k in keys:
        out[k] = torch.stack([d[k] for d in dict_list], dim=dim)
    return out
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def load_results_all(phase_dir, device):
    """
    Load all the reconstructed tracks during optimization
    and run the SMPL-H body model on the latest saved iteration.

    :param phase_dir directory containing per-iteration result files
    :param device device for the body model
    returns dict of SMPL outputs with leading dims (B, T), or None if the
        optimization ran fewer than 20 iterations
    """
    res_path_dict = get_results_paths(phase_dir)
    max_iter = max(res_path_dict.keys())
    if int(max_iter) < 20:
        # too few iterations saved -- treat as a failed/incomplete run
        print("max_iter", max_iter)
        return None

    res = load_result(res_path_dict[max_iter])["world"]
    # results is dict with (B, T, *) tensors
    trans = res["trans"]
    B, T, _ = trans.shape
    root_orient = res["root_orient"]
    pose_body = res["pose_body"]
    # betas are per-track constants; tile them across time
    betas = res["betas"].reshape(B, 1, -1).expand(B, T, -1)
    body_model = load_body_model(B * T, "smplh", "neutral", device)
    return run_smpl_batch(
        body_model,
        device,
        trans=trans,
        root_orient=root_orient,
        betas=betas,
        pose_body=pose_body,
    )
|
slahmr/slahmr/geometry/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import camera
|
| 2 |
+
from . import mesh
|
| 3 |
+
from . import pcl
|
| 4 |
+
from . import plane
|
| 5 |
+
from . import rotation
|
slahmr/slahmr/geometry/camera.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def perspective_projection(
    points, focal_length, camera_center, rotation=None, translation=None
):
    """
    Adapted from https://github.com/mkocabas/VIBE/blob/master/lib/models/spin.py
    This function computes the perspective projection of a set of points.
    Input:
        points (bs, N, 3): 3D points
        focal_length (bs, 2): Focal length
        camera_center (bs, 2): Camera center
        rotation (bs, 3, 3): OPTIONAL Camera rotation
        translation (bs, 3): OPTIONAL Camera translation
    """
    bs = points.shape[0]
    # assemble the per-batch intrinsics matrix K
    K = torch.zeros([bs, 3, 3], device=points.device)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 1]
    K[:, 2, 2] = 1.0
    K[:, :-1, -1] = camera_center

    # optionally move the points into the camera frame
    if rotation is not None and translation is not None:
        points = torch.einsum("bij,bkj->bki", rotation, points)
        points = points + translation.unsqueeze(1)

    # perspective divide, then apply the intrinsics
    homog = points / points[..., 2:3]
    pixels = torch.einsum("bij,bkj->bki", K, homog)
    return pixels[:, :, :-1]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def reproject(points3d, cam_R, cam_t, cam_f, cam_center):
    """
    reproject points3d into the scene cameras
    :param points3d (B, T, N, 3)
    :param cam_R (B, T, 3, 3)
    :param cam_t (B, T, 3)
    :param cam_f (T, 2)
    :param cam_center (T, 2)
    returns (B, T, N, 2) pixel coordinates
    """
    # world -> camera
    p_cam = torch.einsum("btij,btnj->btni", cam_R, points3d)
    p_cam = p_cam + cam_t[..., None, :]  # (B, T, N, 3)
    # perspective divide, then intrinsics
    uv = p_cam[..., :2] / p_cam[..., 2:3]
    return cam_f[None, :, None] * uv + cam_center[None, :, None]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def focal2fov(focal, R):
    """
    Field of view (radians) for a sensor half-extent R at the given focal length.
    :param focal, focal length
    :param R, either W / 2 or H / 2
    """
    half_angle = np.arctan(R / focal)
    return 2 * half_angle
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def fov2focal(fov, R):
    """
    Focal length for a sensor half-extent R at the given field of view.
    :param fov, field of view in radians
    :param R, either W / 2 or H / 2
    """
    half_angle = fov / 2
    return R / np.tan(half_angle)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def compute_lookat_box(bb_min, bb_max, intrins):
    """
    The center and distance to a scene with bb_min, bb_max
    to place a camera with given intrinsics
    :param bb_min (3,)
    :param bb_max (3,)
    :param intrins (fx, fy, cx, cy) of camera
    returns center (3,), cam_dist (scalar)
    """
    fx, fy, cx, cy = intrins
    lo = torch.tensor(bb_min)
    hi = torch.tensor(bb_max)
    center = 0.5 * (lo + hi)
    diag = torch.linalg.norm(hi - lo)
    # ratio of focal magnitude to principal-point magnitude sets the framing
    ratio = np.sqrt(fx**2 + fy**2) / np.sqrt(cx**2 + cy**2)
    cam_dist = 0.75 * diag * ratio
    return center, cam_dist
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def lookat_origin(cam_dist, view_angle=-np.pi / 6):
    """
    Camera pose at distance cam_dist looking toward the origin from an
    elevation of |view_angle| radians.
    :param cam_dist (float)
    :param view_angle (float)
    returns rot (3, 3), pos (3,)
    """
    dist = np.abs(cam_dist)
    angle = np.abs(view_angle)
    pos = dist * torch.tensor([0, np.sin(angle), np.cos(angle)])
    return rotx(angle), pos
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def lookat_matrix(source_pos, target_pos, up):
    """
    IMPORTANT: USES RIGHT UP BACK XYZ CONVENTION
    :param source_pos (*, 3)
    :param target_pos (*, 3)
    :param up (3,)
    returns (*, 4, 4) pose with rotation columns [right, up, back] and
        translation source_pos
    """
    *dims, _ = source_pos.shape
    up = up.reshape(*(1,) * len(dims), 3)
    up = up / torch.linalg.norm(up, dim=-1, keepdim=True)
    # NOTE(review): "back" is computed as target - source (points toward the
    # target) -- verify the sign against the stated right-up-back convention
    back = normalize(target_pos - source_pos)
    # build an orthonormal frame from up x back
    right = normalize(torch.linalg.cross(up, back))
    up = normalize(torch.linalg.cross(back, right))
    R = torch.stack([right, up, back], dim=-1)
    return make_4x4_pose(R, source_pos)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def normalize(x):
    # L2-normalize along the last dim.
    # NOTE(review): this definition is shadowed by a second `normalize` defined
    # later in this module (sqrt-of-sum form); both compute the same thing
    return x / torch.linalg.norm(x, dim=-1, keepdim=True)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def invert_camera(R, t):
    """
    :param R (*, 3, 3)
    :param t (*, 3)
    returns Ri (*, 3, 3), ti (*, 3) such that [Ri, ti] undoes [R, t]
    """
    R = torch.tensor(R)
    t = torch.tensor(t)
    # a rotation's inverse is its transpose
    R_inv = R.transpose(-1, -2)
    t_inv = -torch.einsum("...ij,...j->...i", R_inv, t)
    return R_inv, t_inv
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def compose_cameras(R1, t1, R2, t2):
    """
    composes [R1, t1] and [R2, t2]: applying the result is the same as
    applying [R2, t2] first, then [R1, t1]
    :param R1 (*, 3, 3)
    :param t1 (*, 3)
    :param R2 (*, 3, 3)
    :param t2 (*, 3)
    returns R (*, 3, 3), t (*, 3)
    """
    R_out = torch.einsum("...ij,...jk->...ik", R1, R2)
    t_out = t1 + torch.einsum("...ij,...j->...i", R1, t2)
    return R_out, t_out
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def matmul_nd(A, x):
    """
    multiply batch matrix A to batch nd tensors
    :param A (B, m, n)
    :param x (B, *dims, m)
    """
    B, m, n = A.shape
    assert len(A) == len(x)
    assert x.shape[-1] == m
    B, *dims, _ = x.shape
    # broadcast A over the middle dims of x
    A_exp = A.reshape(B, *(1,) * len(dims), m, n)
    return torch.matmul(A_exp, x.unsqueeze(-1)).squeeze(-1)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def view_matrix(z, up, pos):
    """
    :param z (*, 3) up (*, 3) pos (*, 3)
    returns (*, 4, 4)
    """
    *dims, _ = z.shape
    # orthonormalize: x from up x z, then y from z x x
    x = normalize(torch.linalg.cross(up, z))
    y = normalize(torch.linalg.cross(z, x))
    # homogeneous bottom row [0, 0, 0, 1], broadcast over batch dims
    bottom = (
        torch.tensor([0, 0, 0, 1], dtype=torch.float32)
        .reshape(*(1,) * len(dims), 1, 4)
        .expand(*dims, 1, 4)
    )

    # columns are [x, y, z, pos] -> [R | t] with the homogeneous row appended
    return torch.cat([torch.stack([x, y, z, pos], dim=-1), bottom], dim=-2)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def average_pose(poses):
    """
    :param poses (N, 4, 4)
    returns average pose (4, 4)
    """
    # mean position, and renormalized mean up (col 1) / forward (col 2) axes
    center = poses[:, :3, 3].mean(0)
    up = normalize(poses[:, :3, 1].sum(0))
    z = normalize(poses[:, :3, 2].sum(0))
    return view_matrix(z, up, center)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def project_so3(M, eps=1e-4):
    """
    Project matrices to rotations via SVD (special orthogonalization).

    :param M (N, *, 3, 3)
    :param eps unused -- kept for interface compatibility
    """
    N, *dims, _, _ = M.shape
    # NOTE(review): this scales each column of M by a random factor in [1, 2)
    # before the SVD -- presumably to break degenerate singular values, but it
    # makes the output non-deterministic; confirm this is intended
    M = M * (1 + torch.rand(N, *dims, 1, 3, device=M.device))
    U, D, Vt = torch.linalg.svd(M)  # (N, *, 3, 3), (N, *, 3), (N, *, 3, 3)
    # sign of det(U Vt) tells whether we need to flip the last singular vector
    detuvt = torch.linalg.det(torch.matmul(U, Vt))  # (N, *)
    S = torch.cat(
        [torch.ones(N, *dims, 2, device=M.device), detuvt[..., None]], dim=-1
    )  # (N, *, 3)
    return torch.matmul(U, torch.matmul(torch.diag_embed(S), Vt))
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def make_translation(t):
    # (*, 4, 4) pose with identity rotation and translation t (*, 3)
    return make_4x4_pose(torch.eye(3), t)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
    """
    4x4 pose from Euler angles applied in the given order.

    :param rx, ry, rz (float) rotations about x, y, z in radians
    :param order (str) application order; "xyz" applies Rx first
    returns (4, 4) pose with zero translation
    raises ValueError on an unsupported order (previously a NameError,
        since R was left unbound)
    """
    Rx = rotx(rx)
    Ry = roty(ry)
    Rz = rotz(rz)
    if order == "xyz":
        R = Rz @ Ry @ Rx
    elif order == "xzy":
        R = Ry @ Rz @ Rx
    elif order == "yxz":
        R = Rz @ Rx @ Ry
    elif order == "yzx":
        R = Rx @ Rz @ Ry
    elif order == "zyx":
        R = Rx @ Ry @ Rz
    elif order == "zxy":
        R = Ry @ Rx @ Rz
    else:
        raise ValueError(f"unsupported rotation order {order}")
    return make_4x4_pose(R, torch.zeros(3))
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def make_4x4_pose(R, t):
    """
    :param R (*, 3, 3)
    :param t (*, 3)
    return (*, 4, 4)
    """
    batch_dims = R.shape[:-2]
    # [R | t] top 3x4 block
    top = torch.cat([R, t.view(*batch_dims, 3, 1)], dim=-1)
    # homogeneous bottom row, broadcast over the batch dims
    last_row = (
        torch.tensor([0, 0, 0, 1], device=R.device)
        .reshape(*(1,) * len(batch_dims), 1, 4)
        .expand(*batch_dims, 1, 4)
    )
    return torch.cat([top, last_row], dim=-2)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def normalize(x):
    # L2-normalize along the last dim.
    # NOTE(review): duplicate of the earlier `normalize` in this module; this
    # later definition wins at import time (both are equivalent)
    return x / torch.sqrt(torch.sum(x**2, dim=-1, keepdim=True))
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def rotx(theta):
    """Rotation matrix (3, 3) about the x axis by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    return torch.tensor(
        [[1, 0, 0], [0, c, -s], [0, s, c]],
        dtype=torch.float32,
    )
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def roty(theta):
    """Rotation matrix (3, 3) about the y axis by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    return torch.tensor(
        [[c, 0, s], [0, 1, 0], [-s, 0, c]],
        dtype=torch.float32,
    )
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def rotz(theta):
    """Rotation matrix (3, 3) about the z axis by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    return torch.tensor(
        [[c, -s, 0], [s, c, 0], [0, 0, 1]],
        dtype=torch.float32,
    )
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def relative_pose_c2w(Rwc1, Rwc2, twc1, twc2):
    """
    compute relative pose from cam 1 to cam 2 given c2w pose matrices
    :param Rwc1, Rwc2 (N, 3, 3) cam1, cam2 to world rotations
    :param twc1, twc2 (N, 3) cam1, cam2 to world translations
    returns R21 (N, 3, 3) t21 (N, 3)
    """
    t1 = twc1.view(-1, 3, 1)
    t2 = twc2.view(-1, 3, 1)
    # invert cam2's pose to get world -> cam2
    R_c2w = Rwc2.transpose(-1, -2)
    t_c2w = -torch.matmul(R_c2w, t2)
    # compose: cam1 -> world -> cam2
    R21 = torch.matmul(R_c2w, Rwc1)
    t21 = t_c2w + torch.matmul(R_c2w, t1)
    return R21, t21[..., 0]
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def relative_pose_w2c(Rc1w, Rc2w, tc1w, tc2w):
    """
    compute relative pose from cam 1 to cam 2 given w2c camera matrices
    :param Rc1w, Rc2w (N, 3, 3) world to cam1, cam2 rotations
    :param tc1w, tc2w (N, 3) world to cam1, cam2 translations
    returns R21 (N, 3, 3) t21 (N, 3)
    """
    t1 = tc1w.view(-1, 3, 1)
    t2 = tc2w.view(-1, 3, 1)
    # invert cam1's extrinsics to get cam1 -> world
    R_wc1 = Rc1w.transpose(-1, -2)
    t_wc1 = -torch.matmul(R_wc1, t1)
    # compose: cam1 -> world -> cam2
    R21 = torch.matmul(Rc2w, R_wc1)
    t21 = t2 + torch.matmul(Rc2w, t_wc1)
    return R21, t21[..., 0]
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def project(xyz_c, center, focal, eps=1e-5):
    """
    :param xyz_c (*, 3) 3d point in camera coordinates
    :param center (*, 2) principal point
    :param focal (1) focal length
    :param eps small constant guarding division by zero depth
    return (*, 2)
    """
    xy = xyz_c[..., :2]
    depth = xyz_c[..., 2:3] + eps
    return focal * xy / depth + center  # (N, *, 2)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def convert_yup(xyz):
    """
    converts points in x right y down z forward to x right y up z back
    :param xyz (*, 3)
    """
    # keep x, negate y and z
    return torch.stack(
        [xyz[..., 0], -xyz[..., 1], -xyz[..., 2]], dim=-1
    )
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def inv_project(uv, z, center, focal, yup=True):
    """
    Back-project pixel coordinates to 3D at the given depth.
    :param uv (*, 2)
    :param z (*, 1)
    :param center (*, 2)
    :param focal (1)
    :param yup (bool) output x-right y-up z-back if True, else y-down z-forward
    :returns (*, 3)
    """
    offset = (uv - center) / focal
    if yup:
        rays = torch.cat(
            [offset[..., :1], -offset[..., 1:2], -torch.ones_like(offset[..., :1])],
            dim=-1,
        )
    else:
        rays = torch.cat([offset, torch.ones_like(offset[..., :1])], dim=-1)
    return z * rays  # (N, *, 3)
|
slahmr/slahmr/geometry/mesh.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import trimesh
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_mesh_bb(mesh):
    """
    Axis-aligned bounding box of a mesh.

    :param mesh - trimesh mesh object
    returns bb_min (3), bb_max (3)
    """
    # BUG FIX: min/max were swapped -- bb_min held the componentwise maximum
    # and bb_max the minimum, inverting the box for all callers
    bb_min = mesh.vertices.min(axis=0)
    bb_max = mesh.vertices.max(axis=0)
    return bb_min, bb_max
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_scene_bb(meshes):
    """
    :param meshes - (potentially nested) list of trimesh objects
    returns bb_min (3), bb_max (3)
    """
    # base case: a single mesh
    if isinstance(meshes, trimesh.Trimesh):
        return get_mesh_bb(meshes)

    # recurse into the (possibly nested) list and merge the boxes
    mins, maxs = zip(*(get_scene_bb(m) for m in meshes))
    mins = np.stack(mins, axis=0)
    maxs = np.stack(maxs, axis=0)
    return mins.min(axis=0), maxs.max(axis=0)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def make_batch_mesh(verts, faces, colors):
    """
    convenience function to make batch of meshes
    meshs have same faces in batch, verts have same color in mesh
    :param verts (B, V, 3)
    :param faces (F, 3) shared by every mesh in the batch
    :param colors (B, 3) one color per mesh, broadcast to all its vertices
    returns list of B trimesh objects
    """
    B, V, _ = verts.shape
    return [make_mesh(verts[b], faces, colors[b, None].expand(V, -1)) for b in range(B)]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def make_mesh(verts, faces, colors=None, yup=True):
    """
    create a trimesh object for the faces and vertices
    :param verts (V, 3) tensor
    :param faces (F, 3) tensor
    :param colors (optional) (V, 3) tensor; defaults to uniform gray
    :param yup (optional bool) whether or not to save with Y up
    """
    verts = verts.detach().cpu().numpy()
    faces = faces.detach().cpu().numpy()
    if yup:
        # negate y and z: x-right y-down z-forward -> x-right y-up z-back
        verts = np.array([1, -1, -1])[None, :] * verts
    if colors is None:
        colors = np.ones_like(verts) * 0.5
    else:
        colors = colors.detach().cpu().numpy()
    # process=False keeps vertex order and count unchanged
    return trimesh.Trimesh(
        vertices=verts, faces=faces, vertex_colors=colors, process=False
    )
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def save_mesh_scenes(out_dir, scenes):
    """
    :param out_dir output directory
    :param scenes, list of scenes (list of meshes)
    Single-mesh scenes are exported as .obj, multi-mesh scenes as .glb.
    """
    assert isinstance(scenes, list)
    assert isinstance(scenes[0], list)
    # number of meshes per scene, taken from the first scene
    B = len(scenes[0])
    if B == 1:
        save_meshes_to_obj(out_dir, [x[0] for x in scenes])
    else:
        save_scenes_to_glb(out_dir, scenes)
|
| 75 |
+
|
| 76 |
+
def save_scenes_to_glb(out_dir, scenes):
    """
    Saves a list of scenes (list of meshes) each to glb files
    named scene_000.glb, scene_001.glb, ... under out_dir.
    """
    os.makedirs(out_dir, exist_ok=True)
    for t, meshes in enumerate(scenes):
        save_meshes_to_glb(f"{out_dir}/scene_{t:03d}.glb", meshes)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def save_meshes_to_glb(path, meshes, names=None):
    """
    put trimesh meshes in a scene and export to glb
    :param path output .glb file path
    :param meshes list of trimesh objects
    :param names (optional) node name per mesh; defaults to mesh_000, ...
    """
    if names is not None:
        assert len(meshes) == len(names)

    scene = trimesh.Scene()
    for i, mesh in enumerate(meshes):
        name = f"mesh_{i:03d}" if names is None else names[i]
        scene.add_geometry(mesh, node_name=name)

    with open(path, "wb") as f:
        f.write(trimesh.exchange.gltf.export_glb(scene, include_normals=True))
| 100 |
+
|
| 101 |
+
def save_meshes_to_obj(out_dir, meshes, names=None):
    """
    Export each mesh to its own obj file in out_dir.

    :param out_dir output directory (created if missing)
    :param meshes list of trimesh meshes
    :param names optional list of file basenames, one per mesh
    """
    if names is not None:
        assert len(meshes) == len(names)

    os.makedirs(out_dir, exist_ok=True)
    for idx, geom in enumerate(meshes):
        basename = names[idx] if names is not None else f"mesh_{idx:03d}"
        out_path = os.path.join(out_dir, f"{basename}.obj")
        with open(out_path, "w") as f:
            geom.export(f, file_type="obj", include_color=False, include_normals=True)
|
slahmr/slahmr/geometry/pcl.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def read_pcl_tensor(path):
    """Read a point cloud from path and return it as a torch tensor."""
    return torch.from_numpy(read_pcl(path))
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """Align X to Y with a similarity transform using Umeyama's method.

    Finds s, R, t such that X' = s * R @ X + t is aligned with Y in the
    least-squares sense.

    :param Y (*, N, 3) target point sets
    :param X (*, N, 3) source point sets
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale (bool) if True, constrain the scale s to 1
    :returns s (*, 1), R (*, 3, 3), t (*, 3)
    """
    *dims, N, _ = Y.shape
    # effective number of points per batch element (*, 1, 1);
    # create on the input's device/dtype so CUDA/float64 inputs work
    N = torch.ones(*dims, 1, 1, device=Y.device, dtype=Y.dtype) * N

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # subtract (weighted) mean
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # cross-covariance of the centered point sets
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # reflection correction: flip the last singular direction where
    # det(U) * det(V) < 0, so that R is a proper rotation
    S = torch.eye(3, device=Y.device, dtype=Y.dtype)
    S = S.reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    D = torch.diag_embed(D)  # (*, 3, 3)
    if fixed_scale:
        s = torch.ones(*dims, 1, device=Y.device, dtype=Y.dtype)
    else:
        # optimal scale is tr(D S) / var(x0)
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        s = (
            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
                dim=-1, keepdim=True
            )
            / var[..., 0]
        )  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
|
slahmr/slahmr/geometry/plane.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def fit_plane(points):
    """
    Fit a plane to a set of points in the least-squares sense via SVD.

    :param points (*, N, 3)
    :returns (*, 4) plane parameters (unit normal, offset), where
        dot(p, normal) ~= offset for points p on the plane
    """
    centroid = points.mean(dim=-2, keepdim=True)
    # the right singular vector with the smallest singular value is the
    # direction of least variance, i.e. the plane normal
    _, _, Vh = torch.linalg.svd(points - centroid)
    normal = Vh[..., -1, :]  # (*, 3)
    # offset is the mean projection of the points onto the normal
    offset = torch.einsum("...ij,...j->...i", points, normal).mean(dim=-1, keepdim=True)
    return torch.cat([normal, offset], dim=-1)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_plane_transform(up, ground_plane=None, xyz_orig=None):
    """
    Get the R, t rigid transform from a ground plane and a desired origin.

    :param up (3,) up vector of the coordinate frame
    :param ground_plane (4,) (a, b, c, d) where (a, b, c) is the normal
        and d the offset; if None, the identity transform is returned
    :param xyz_orig (3,) desired origin; defaults to the world origin
    :returns R (3, 3), t (3,)
    """
    R = torch.eye(3)
    t = torch.zeros(3)
    if ground_plane is None:
        return R, t

    # compute transform between world up vector and passed in floor
    ground_plane = torch.as_tensor(ground_plane)
    # multiply by sign(d) so the plane offset is non-negative
    ground_plane = torch.sign(ground_plane[3]) * ground_plane

    normal = ground_plane[:3]
    normal = normal / torch.linalg.norm(normal)
    # Rodrigues' rotation formula built from sin/cos of the angle between
    # up and the plane normal.
    # NOTE(review): singular when up is exactly parallel to the normal
    # (ang_sin == 0) -- assumed not to occur for real camera poses; confirm
    v = torch.linalg.cross(up, normal)
    ang_sin = torch.linalg.norm(v)
    ang_cos = up.dot(normal)
    skew_v = torch.as_tensor([[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]])
    R = torch.eye(3) + skew_v + (skew_v @ skew_v) * ((1.0 - ang_cos) / (ang_sin**2))

    # project origin onto plane
    if xyz_orig is None:
        xyz_orig = torch.zeros(3)
    t, _ = compute_plane_intersection(xyz_orig, -normal, ground_plane)

    return R, t
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def parse_floor_plane(floor_plane):
    """
    Convert a floor plane from the optimization format (*, 3), where the
    vector is normal * offset, into (a, b, c, d) format (*, 4) with
    (a, b, c) the unit normal facing "up" in the camera frame and d the
    offset.
    """
    offset = torch.norm(floor_plane, dim=-1, keepdim=True)
    normal = floor_plane / offset

    # in the camera frame -y is up, so a valid floor normal should never
    # have a positive y component (assuming the camera is not sideways or
    # upside down); flip both the normal and the offset where it does
    flip = normal[..., 1:2] > 0.0
    normal = torch.where(flip.expand_as(normal), -normal, normal)
    offset = torch.where(flip, -offset, offset)

    return torch.cat([normal, offset], dim=-1)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def compute_plane_intersection(point, direction, plane):
    """
    Intersect rays with planes.

    Each ray is given by a point in space and a direction. Returns the
    intersection point and the signed parameter s such that
    point + s * direction = intersection_point; s < 0 means -direction
    intersects. Note the ray may be (nearly) parallel to the plane, in
    which case s is large but finite due to the denominator epsilon.

    - point : B x 3
    - direction : B x 3
    - plane : B x 4 (a, b, c, d) where (a, b, c) is the normal and (d) the offset.
    """
    normal, offset = plane[..., :3], plane[..., 3]
    # inline batched dot products; epsilon guards near-parallel rays
    denom = (normal * direction).sum(dim=-1) + 1e-4
    s = (offset - (normal * point).sum(dim=-1)) / denom
    itsct_pt = point + s.reshape((-1, 1)) * direction
    return itsct_pt, s
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def bdot(A1, A2, keepdim=False):
    """
    Batched dot product over the last dimension.

    - A1 : B x D
    - A2 : B x D
    Returns B (or B x 1 when keepdim is True).
    """
    return torch.sum(A1 * A2, dim=-1, keepdim=keepdim)
|
slahmr/slahmr/geometry/rotation.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
    """
    Convert a batch of axis-angle vectors to rotation matrices.

    Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py
    - param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors
    - returns R: torch.tensor (N, 3, 3) rotation matrices
    """
    N = rot_vecs.shape[0]
    device = rot_vecs.device

    # small constant keeps the norm (and the division) away from zero
    angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
    axis = rot_vecs / angle

    cos = torch.cos(angle).unsqueeze(1)
    sin = torch.sin(angle).unsqueeze(1)

    # build the skew-symmetric cross-product matrices K (N, 3, 3)
    rx, ry, rz = torch.split(axis, 1, dim=1)
    zeros = torch.zeros((N, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(N, 3, 3)

    # Rodrigues' formula: R = I + sin(a) K + (1 - cos(a)) K^2
    eye = torch.eye(3, dtype=dtype, device=device).unsqueeze(0)
    return eye + sin * K + (1 - cos) * torch.bmm(K, K)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def quaternion_mul(q0, q1):
    """
    Hamilton product of two quaternions.

    EXPECTS WXYZ
    :param q0 (*, 4)
    :param q1 (*, 4)
    """
    w0, xyz0 = q0[..., :1], q0[..., 1:]
    w1, xyz1 = q1[..., :1], q1[..., 1:]
    # (w0 w1 - v0.v1, w0 v1 + w1 v0 + v0 x v1)
    real = w0 * w1 - (xyz0 * xyz1).sum(dim=-1, keepdim=True)
    imag = w0 * xyz1 + w1 * xyz0 + torch.linalg.cross(xyz0, xyz1)
    return torch.cat([real, imag], dim=-1)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def quaternion_inverse(q, eps=1e-8):
    """
    Quaternion inverse: the conjugate divided by the squared magnitude.

    EXPECTS WXYZ
    :param q (*, 4)
    """
    conjugate = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
    sq_mag = torch.square(q).sum(dim=-1, keepdim=True) + eps
    return conjugate / sq_mag
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def quaternion_slerp(t, q0, q1, eps=1e-8):
    """
    Spherical linear interpolation between two quaternions (WXYZ).

    :param t (*, 1) interpolation weight, must be between 0 and 1
    :param q0 (*, 4)
    :param q1 (*, 4)
    :returns interpolated quaternion (*, 4)
    """
    dims = q0.shape[:-1]
    t = t.view(*dims, 1)

    # work with unit quaternions
    q0 = F.normalize(q0, p=2, dim=-1)
    q1 = F.normalize(q1, p=2, dim=-1)
    dot = (q0 * q1).sum(dim=-1, keepdim=True)

    # make sure we give the shortest rotation path (< 180d)
    neg = dot < 0
    q1 = torch.where(neg, -q1, q1)
    dot = torch.where(neg, -dot, dot)
    angle = torch.acos(dot)

    # if angle is too small, just do linear interpolation
    # (fac = 1/sin(angle) blows up for near-collinear quaternions, so the
    # collin branch of the torch.where below avoids using it there)
    collin = torch.abs(dot) > 1 - eps
    fac = 1 / torch.sin(angle)
    w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)
    w1 = torch.where(collin, t, torch.sin(t * angle) * fac)
    slerp = q0 * w0 + q1 * w1
    return slerp
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def rotation_matrix_to_angle_axis(rotation_matrix):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert rotation matrix to Rodrigues vector by going through the
    quaternion representation.
    """
    quat = rotation_matrix_to_quaternion(rotation_matrix)
    angle_axis = quaternion_to_angle_axis(quat)
    # zero out any NaNs produced by degenerate inputs
    angle_axis[torch.isnan(angle_axis)] = 0.0
    return angle_axis
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def quaternion_to_angle_axis(quaternion):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert quaternion vector to angle axis of rotation.
    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    :param quaternion (*, 4) expects WXYZ
    :returns angle_axis (*, 3)
    """
    # unpack input and compute conversion
    q1 = quaternion[..., 1]
    q2 = quaternion[..., 2]
    q3 = quaternion[..., 3]
    sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3

    sin_theta = torch.sqrt(sin_squared_theta)
    cos_theta = quaternion[..., 0]
    # choose the atan2 branch that keeps the recovered angle in [0, pi]
    two_theta = 2.0 * torch.where(
        cos_theta < 0.0,
        torch.atan2(-sin_theta, -cos_theta),
        torch.atan2(sin_theta, cos_theta),
    )

    # k scales the vector part up to angle-axis magnitude; k_neg is the
    # small-angle limit 2*theta/sin(theta) -> 2 used for zero rotations
    k_pos = two_theta / sin_theta
    k_neg = 2.0 * torch.ones_like(sin_theta)
    k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)

    angle_axis = torch.zeros_like(quaternion)[..., :3]
    angle_axis[..., 0] += q1 * k
    angle_axis[..., 1] += q2 * k
    angle_axis[..., 2] += q3 * k
    return angle_axis
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def quaternion_to_rotation_matrix(quaternion):
    """
    Convert a quaternion to a rotation matrix.
    Taken from https://github.com/kornia/kornia, based on
    https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
    https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
    :param quaternion (N, 4) expects WXYZ order
    returns rotation matrix (N, 3, 3)
    """
    # work with a unit quaternion
    q = F.normalize(quaternion, p=2, dim=-1, eps=1e-12)
    *dims, _ = q.shape
    w, x, y, z = torch.chunk(q, chunks=4, dim=-1)

    # precompute doubled products
    txx, tyy, tzz = 2.0 * x * x, 2.0 * y * y, 2.0 * z * z
    txy, txz, tyz = 2.0 * y * x, 2.0 * z * x, 2.0 * z * y
    twx, twy, twz = 2.0 * x * w, 2.0 * y * w, 2.0 * z * w
    one = torch.tensor(1.0)

    entries = (
        one - (tyy + tzz), txy - twz, txz + twy,
        txy + twz, one - (txx + tzz), tyz - twx,
        txz - twy, tyz + twx, one - (txx + tyy),
    )
    return torch.stack(entries, dim=-1).view(*dims, 3, 3)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def angle_axis_to_quaternion(angle_axis):
    """
    This function is borrowed from https://github.com/kornia/kornia
    Convert angle axis to quaternion in WXYZ order
    :param angle_axis (*, 3)
    :returns quaternion (*, 4) WXYZ order
    """
    theta_sq = (angle_axis**2).sum(dim=-1, keepdim=True)  # (*, 1)
    theta = torch.sqrt(theta_sq)
    half = 0.5 * theta
    ones = torch.ones_like(half)
    # handle zero rotations: sin(a x)/x -> a as x -> 0, so k -> 0.5
    nonzero = theta_sq > 0
    k = torch.where(nonzero, torch.sin(half) / theta, 0.5 * ones)
    w = torch.where(nonzero, torch.cos(half), ones)
    return torch.cat([w, k * angle_axis], dim=-1)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    This function is borrowed from https://github.com/kornia/kornia
    Convert rotation matrix to 4d quaternion vector (WXYZ order).
    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    :param rotation_matrix (*, 3, 3)
    :returns quaternion (*, 4)
    """
    *dims, m, n = rotation_matrix.shape
    # flatten batch dims; the transpose matches the source algorithm's
    # element indexing convention
    rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2)

    # masks select which of four numerically-stable branches to use per
    # element, based on the signs/magnitudes of the diagonal entries
    mask_d2 = rmat_t[:, 2, 2] < eps

    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    # branch 0: candidate quaternion when m00 dominates the diagonal
    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack(
        [
            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
            t0,
            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
        ],
        -1,
    )
    t0_rep = t0.repeat(4, 1).t()

    # branch 1: candidate when m11 dominates
    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack(
        [
            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
            t1,
            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
        ],
        -1,
    )
    t1_rep = t1.repeat(4, 1).t()

    # branch 2: candidate when m22 dominates
    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack(
        [
            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
            t2,
        ],
        -1,
    )
    t2_rep = t2.repeat(4, 1).t()

    # branch 3: candidate when the trace is large (well-conditioned case)
    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack(
        [
            t3,
            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
        ],
        -1,
    )
    t3_rep = t3.repeat(4, 1).t()

    # combine the mutually-exclusive branch masks and pick one candidate
    # (and its normalizer) per element
    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * ~mask_d0_d1
    mask_c2 = ~mask_d2 * mask_d0_nd1
    mask_c3 = ~mask_d2 * ~mask_d0_nd1
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(
        t0_rep * mask_c0
        + t1_rep * mask_c1
        + t2_rep * mask_c2  # noqa
        + t3_rep * mask_c3
    )  # noqa
    q *= 0.5
    return q.reshape(*dims, 4)
|
slahmr/slahmr/humor/__init__.py
ADDED
|
File without changes
|
slahmr/slahmr/humor/amass_utils.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Taken from https://github.com/davrempe/humor
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from body_model.utils import SMPL_JOINTS
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# AMASS sub-dataset names assigned to each predefined split
TRAIN_DATASETS = [
    "CMU",
    "MPI_Limits",
    "TotalCapture",
    "Eyes_Japan_Dataset",
    "KIT",
    "BioMotionLab_NTroje",
    "BMLmovi",
    "EKUT",
    "ACCAD",
]
TEST_DATASETS = ["Transitions_mocap", "HumanEva"]
VAL_DATASETS = ["MPI_HDM05", "SFU", "MPI_mosh"]


# valid split names
SPLITS = ["train", "val", "test", "custom"]
# strategies for dividing the data into splits
SPLIT_BY = [
    "single",  # the data path is a single .npz file. Don't split: train and test are same
    "sequence",  # the data paths are directories of subjects. Collate and split by sequence.
    "subject",  # the data paths are directories of datasets. Collate and split by subject.
    "dataset",  # a single data path to the amass data root is given. The predefined datasets will be used for each split.
]

# supported rotation representations: rotation matrix, axis-angle, 6d
ROT_REPS = ["mat", "aa", "6d"]

# these correspond to [root, left knee, right knee, left heel, right heel, left toe, right toe, left hand, right hand]
CONTACT_ORDERING = [
    "hips",
    "leftLeg",
    "rightLeg",
    "leftFoot",
    "rightFoot",
    "leftToeBase",
    "rightToeBase",
    "leftHand",
    "rightHand",
]
# indices into SMPL_JOINTS for the contact joints listed above
CONTACT_INDS = [SMPL_JOINTS[jname] for jname in CONTACT_ORDERING]

# number of SMPL body joints, excluding the root
NUM_BODY_JOINTS = len(SMPL_JOINTS) - 1
# number of mesh vertices used as keypoints
NUM_KEYPT_VERTS = 43

# every data value a return configuration may enable
DATA_NAMES = [
    "trans",
    "trans_vel",
    "root_orient",
    "root_orient_vel",
    "pose_body",
    "pose_body_vel",
    "joints",
    "joints_vel",
    "joints_orient_vel",
    "verts",
    "verts_vel",
    "contacts",
]

# SMPL parameters + joint positions/velocities, without contacts or verts
SMPL_JOINTS_RETURN_CONFIG = {
    "trans": True,
    "trans_vel": True,
    "root_orient": True,
    "root_orient_vel": True,
    "pose_body": True,
    "pose_body_vel": False,
    "joints": True,
    "joints_vel": True,
    "joints_orient_vel": False,
    "verts": False,
    "verts_vel": False,
    "contacts": False,
}

# same as SMPL_JOINTS_RETURN_CONFIG but with contact labels enabled
SMPL_JOINTS_CONTACTS_RETURN_CONFIG = {
    "trans": True,
    "trans_vel": True,
    "root_orient": True,
    "root_orient_vel": True,
    "pose_body": True,
    "pose_body_vel": False,
    "joints": True,
    "joints_vel": True,
    "joints_orient_vel": False,
    "verts": False,
    "verts_vel": False,
    "contacts": True,
}

# additionally enables keypoint vertices (but not their velocities)
ALL_RETURN_CONFIG = {
    "trans": True,
    "trans_vel": True,
    "root_orient": True,
    "root_orient_vel": True,
    "pose_body": True,
    "pose_body_vel": False,
    "joints": True,
    "joints_vel": True,
    "joints_orient_vel": False,
    "verts": True,
    "verts_vel": False,
    "contacts": True,
}

# named return configurations selectable by string key
RETURN_CONFIGS = {
    "smpl+joints+contacts": SMPL_JOINTS_CONTACTS_RETURN_CONFIG,
    "smpl+joints": SMPL_JOINTS_RETURN_CONFIG,
    "all": ALL_RETURN_CONFIG,
}
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def data_name_list(return_config):
    """
    Return the list of data value names enabled in the given configuration.
    """
    cfg = RETURN_CONFIGS[return_config]
    return [name for name in DATA_NAMES if cfg[name]]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def data_dim(dname, rot_rep_size=9):
    """
    Return the per-frame dimension of the data with the given name.

    :param dname one of DATA_NAMES
    :param rot_rep_size size of the rotation representation used for
        rotation-valued data (e.g. 9 for matrices, 3 for axis-angle,
        6 for the 6d representation)
    :raises ValueError if dname is not a recognized data name
    """
    if dname in ["trans", "trans_vel", "root_orient_vel"]:
        return 3
    elif dname in ["root_orient"]:
        return rot_rep_size
    elif dname in ["pose_body"]:
        return NUM_BODY_JOINTS * rot_rep_size
    elif dname in ["pose_body_vel"]:
        return NUM_BODY_JOINTS * 3
    elif dname in ["joints", "joints_vel"]:
        return len(SMPL_JOINTS) * 3
    elif dname in ["joints_orient_vel"]:
        return 1
    elif dname in ["verts", "verts_vel"]:
        return NUM_KEYPT_VERTS * 3
    elif dname in ["contacts"]:
        return len(CONTACT_ORDERING)
    else:
        # raise instead of print + exit() so callers can handle the error
        raise ValueError("The given data name %s is not valid!" % (dname))
|
slahmr/slahmr/humor/humor_model.py
ADDED
|
@@ -0,0 +1,1655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Taken from https://github.com/davrempe/humor
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import time, os
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.distributions.normal import Normal
|
| 11 |
+
|
| 12 |
+
from .amass_utils import data_name_list, data_dim
|
| 13 |
+
from .transforms import (
|
| 14 |
+
convert_to_rotmat,
|
| 15 |
+
compute_world2aligned_mat,
|
| 16 |
+
rotation_matrix_to_angle_axis,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from body_model.specs import SMPL_JOINTS, SMPLH_PATH
|
| 20 |
+
from body_model.body_model import BodyModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# supported rotation representations for network inputs (axis-angle, 6D, rotation matrix)
IN_ROT_REPS = ["aa", "6d", "mat"]
# supported rotation representations for network outputs
OUT_ROT_REPS = ["aa", "6d", "9d"]
# per-joint dimensionality of each rotation representation
ROT_REP_SIZE = {"aa": 3, "6d": 6, "mat": 9, "9d": 9}
NUM_SMPL_JOINTS = len(SMPL_JOINTS)
NUM_BODY_JOINTS = NUM_SMPL_JOINTS - 1  # no root
# number of SMPL shape (beta) coefficients used throughout this module
BETA_SIZE = 16

# available network architectures for each sub-network (only MLP implemented)
POSTERIOR_OPTIONS = ["mlp"]
PRIOR_OPTIONS = ["mlp"]
DECODER_OPTIONS = ["mlp"]

# NOTE(review): not referenced in this chunk — presumably the set of state
# names for which world-to-aligned-frame transforms are cached; confirm usage
# elsewhere in the file before relying on it.
WORLD2ALIGN_NAME_CACHE = {
    "root_orient": None,
    "trans": None,
    "joints": None,
    "verts": None,
    "joints_vel": None,
    "verts_vel": None,
    "trans_vel": None,
    "root_orient_vel": None,
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def step(
    model, loss_func, data, dataset, device, cur_epoch, mode="train", use_gt_p=1.0
):
    """
    Run one training/eval step: pull the needed tensors out of the batch,
    run the model (fully supervised or with scheduled sampling), and compute
    the loss.

    - use_gt_p : the probability of using ground truth as input to each step
      rather than the model's own prediction (1.0 is fully supervised, 0.0 is
      fully autoregressive).

    Returns (loss, stats_dict) from loss_func.

    NOTE(review): `dataset` and `mode` are accepted but unused here — kept for
    interface compatibility with callers.
    """
    use_sched_samp = use_gt_p < 1.0
    batch_in, batch_out, meta = data

    prep_data = model.prepare_input(
        batch_in,
        device,
        data_out=batch_out,
        return_input_dict=True,
        return_global_dict=use_sched_samp,
    )
    if use_sched_samp:
        x_past, x_t, gt_dict, input_dict, global_gt_dict = prep_data
    else:
        x_past, x_t, gt_dict, input_dict = prep_data

    B, T, S_in, _ = x_past.size()
    S_out = x_t.size(2)

    if not use_sched_samp:
        # fully supervised phase:
        # GT is used at every step, so form all steps from all sequences into
        # one large batch and get per-step predictions
        x_past_batched = x_past.reshape((B * T, S_in, -1))
        x_t_batched = x_t.reshape((B * T, S_out, -1))
        out_dict = model(x_past_batched, x_t_batched)
    else:
        # scheduled sampling or fully autoregressive phase
        init_input_dict = dict()
        for k in input_dict.keys():
            # only need the first step for initialization
            init_input_dict[k] = input_dict[k][:, 0, :, :]
        # this out_dict is the global state
        sched_samp_out = model.scheduled_sampling(
            x_past,
            x_t,
            init_input_dict,
            p=use_gt_p,
            gender=meta["gender"],
            betas=meta["betas"].to(device),
            need_global_out=(not model.detach_sched_samp),
        )
        if model.detach_sched_samp:
            out_dict = sched_samp_out
        else:
            out_dict, _ = sched_samp_out
        # gt must be global state for supervision in this case
        if not model.detach_sched_samp:
            print("USING global supervision")
            gt_dict = global_gt_dict

    # loss can be computed per output step in parallel; batch the dicts accordingly
    for k in out_dict.keys():
        if k == "posterior_distrib" or k == "prior_distrib":
            m, v = out_dict[k]
            m = m.reshape((B * T, -1))
            v = v.reshape((B * T, -1))
            out_dict[k] = (m, v)
        else:
            out_dict[k] = out_dict[k].reshape((B * T * S_out, -1))
    for k in gt_dict.keys():
        gt_dict[k] = gt_dict[k].reshape((B * T * S_out, -1))

    # broadcast per-sequence gender/betas metadata to every output step
    gender_in = np.broadcast_to(
        np.array(meta["gender"]).reshape((B, 1, 1, 1)), (B, T, S_out, 1)
    )
    gender_in = gender_in.reshape((B * T * S_out, 1))
    # BETA_SIZE instead of a hard-coded 16, for consistency with the module constant
    betas_in = (
        meta["betas"]
        .reshape((B, T, 1, -1))
        .expand((B, T, S_out, BETA_SIZE))
        .to(device)
    )
    betas_in = betas_in.reshape((B * T * S_out, BETA_SIZE))
    loss, stats_dict = loss_func(
        out_dict, gt_dict, cur_epoch, gender=gender_in, betas=betas_in
    )

    return loss, stats_dict
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class HumorModel(nn.Module):
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
in_rot_rep="aa",
|
| 139 |
+
out_rot_rep="aa",
|
| 140 |
+
latent_size=48,
|
| 141 |
+
steps_in=1,
|
| 142 |
+
conditional_prior=True, # use a learned prior rather than standard normal
|
| 143 |
+
output_delta=True, # output change in state from decoder rather than next step directly
|
| 144 |
+
posterior_arch="mlp",
|
| 145 |
+
decoder_arch="mlp",
|
| 146 |
+
prior_arch="mlp",
|
| 147 |
+
model_data_config="smpl+joints+contacts",
|
| 148 |
+
detach_sched_samp=True, # if true, detaches outputs of previous step so gradients don't flow through many steps
|
| 149 |
+
model_use_smpl_joint_inputs=False, # if true, uses smpl joints rather than regressed joints to input at next step (during rollout and sched samp)
|
| 150 |
+
model_smpl_batch_size=1, # if using smpl joint inputs this should be batch_size of the smpl model (aka data input to rollout)
|
| 151 |
+
):
|
| 152 |
+
super(HumorModel, self).__init__()
|
| 153 |
+
self.ignore_keys = []
|
| 154 |
+
|
| 155 |
+
self.steps_in = steps_in
|
| 156 |
+
self.steps_out = 1
|
| 157 |
+
self.out_step_size = 1
|
| 158 |
+
self.detach_sched_samp = detach_sched_samp
|
| 159 |
+
self.output_delta = output_delta
|
| 160 |
+
|
| 161 |
+
if self.steps_out > 1:
|
| 162 |
+
raise NotImplementedError("Only supported single step output currently.")
|
| 163 |
+
|
| 164 |
+
if out_rot_rep not in OUT_ROT_REPS:
|
| 165 |
+
raise Exception(
|
| 166 |
+
"Not a valid output rotation representation: %s" % (out_rot_rep)
|
| 167 |
+
)
|
| 168 |
+
if in_rot_rep not in IN_ROT_REPS:
|
| 169 |
+
raise Exception(
|
| 170 |
+
"Not a valid input rotation representation: %s" % (in_rot_rep)
|
| 171 |
+
)
|
| 172 |
+
self.out_rot_rep = out_rot_rep
|
| 173 |
+
self.in_rot_rep = in_rot_rep
|
| 174 |
+
|
| 175 |
+
if posterior_arch not in POSTERIOR_OPTIONS:
|
| 176 |
+
raise Exception("Not a valid encoder architecture: %s" % (posterior_arch))
|
| 177 |
+
if decoder_arch not in DECODER_OPTIONS:
|
| 178 |
+
raise Exception("Not a valid decoder architecture: %s" % (decoder_arch))
|
| 179 |
+
if conditional_prior and prior_arch not in PRIOR_OPTIONS:
|
| 180 |
+
raise Exception("Not a valid prior architecture: %s" % (prior_arch))
|
| 181 |
+
self.posterior_arch = posterior_arch
|
| 182 |
+
self.decoder_arch = decoder_arch
|
| 183 |
+
self.prior_arch = prior_arch
|
| 184 |
+
|
| 185 |
+
# get the list of data names for this config
|
| 186 |
+
self.data_names = data_name_list(model_data_config)
|
| 187 |
+
self.aux_in_data_names = (
|
| 188 |
+
self.aux_out_data_names
|
| 189 |
+
) = None # auxiliary data will be returned as part of the input/output dictionary, but not the actual network input/output tensor
|
| 190 |
+
self.pred_contacts = False
|
| 191 |
+
if (
|
| 192 |
+
model_data_config.find("contacts") >= 0
|
| 193 |
+
): # network is outputting contact classification as well and need to supervise, but not given as input to net.
|
| 194 |
+
self.data_names.remove("contacts")
|
| 195 |
+
self.aux_out_data_names = ["contacts"]
|
| 196 |
+
self.pred_contacts = True
|
| 197 |
+
|
| 198 |
+
self.need_trans2joint = (
|
| 199 |
+
"joints" in self.data_names or "verts" in self.data_names
|
| 200 |
+
)
|
| 201 |
+
self.model_data_config = model_data_config
|
| 202 |
+
|
| 203 |
+
self.input_rot_dim = ROT_REP_SIZE[self.in_rot_rep]
|
| 204 |
+
self.input_dim_list = [
|
| 205 |
+
data_dim(dname, rot_rep_size=self.input_rot_dim)
|
| 206 |
+
for dname in self.data_names
|
| 207 |
+
]
|
| 208 |
+
self.input_data_dim = sum(self.input_dim_list)
|
| 209 |
+
|
| 210 |
+
self.output_rot_dim = ROT_REP_SIZE[self.out_rot_rep]
|
| 211 |
+
self.output_dim_list = [
|
| 212 |
+
data_dim(dname, rot_rep_size=self.output_rot_dim)
|
| 213 |
+
for dname in self.data_names
|
| 214 |
+
]
|
| 215 |
+
self.delta_output_dim_list = [
|
| 216 |
+
data_dim(dname, rot_rep_size=ROT_REP_SIZE["mat"])
|
| 217 |
+
for dname in self.data_names
|
| 218 |
+
]
|
| 219 |
+
|
| 220 |
+
if self.pred_contacts:
|
| 221 |
+
# account for contact classification output
|
| 222 |
+
self.output_dim_list.append(data_dim("contacts"))
|
| 223 |
+
self.delta_output_dim_list.append(data_dim("contacts"))
|
| 224 |
+
|
| 225 |
+
self.output_data_dim = sum(self.output_dim_list)
|
| 226 |
+
|
| 227 |
+
self.latent_size = latent_size
|
| 228 |
+
past_data_dim = self.steps_in * self.input_data_dim
|
| 229 |
+
t_data_dim = self.steps_out * self.input_data_dim
|
| 230 |
+
|
| 231 |
+
# posterior encoder (given past and future, predict latent transition distribution)
|
| 232 |
+
print("Using posterior architecture: %s" % (self.posterior_arch))
|
| 233 |
+
if self.posterior_arch == "mlp":
|
| 234 |
+
layer_list = [
|
| 235 |
+
past_data_dim + t_data_dim,
|
| 236 |
+
1024,
|
| 237 |
+
1024,
|
| 238 |
+
1024,
|
| 239 |
+
1024,
|
| 240 |
+
self.latent_size * 2,
|
| 241 |
+
]
|
| 242 |
+
self.encoder = MLP(
|
| 243 |
+
layers=layer_list, # mu and sigma output
|
| 244 |
+
nonlinearity=nn.ReLU,
|
| 245 |
+
use_gn=True,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# decoder (given past and latent transition, predict future) for the immediate next step
|
| 249 |
+
print("Using decoder architecture: %s" % (self.decoder_arch))
|
| 250 |
+
decoder_input_dim = past_data_dim + self.latent_size
|
| 251 |
+
if self.decoder_arch == "mlp":
|
| 252 |
+
layer_list = [decoder_input_dim, 1024, 1024, 512, self.output_data_dim]
|
| 253 |
+
self.decoder = MLP(
|
| 254 |
+
layers=layer_list,
|
| 255 |
+
nonlinearity=nn.ReLU,
|
| 256 |
+
use_gn=True,
|
| 257 |
+
skip_input_idx=past_data_dim, # skip connect the latent to every layer
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# prior (if conditional, given past predict latent transition distribution)
|
| 261 |
+
self.use_conditional_prior = conditional_prior
|
| 262 |
+
if self.use_conditional_prior:
|
| 263 |
+
print("Using prior architecture: %s" % (self.prior_arch))
|
| 264 |
+
layer_list = [past_data_dim, 1024, 1024, 1024, 1024, self.latent_size * 2]
|
| 265 |
+
self.prior_net = MLP(
|
| 266 |
+
layers=layer_list, # mu and sigma output
|
| 267 |
+
nonlinearity=nn.ReLU,
|
| 268 |
+
use_gn=True,
|
| 269 |
+
)
|
| 270 |
+
else:
|
| 271 |
+
print("Using standard normal prior.")
|
| 272 |
+
|
| 273 |
+
self.use_smpl_joint_inputs = model_use_smpl_joint_inputs
|
| 274 |
+
self.smpl_batch_size = model_smpl_batch_size
|
| 275 |
+
if self.use_smpl_joint_inputs:
|
| 276 |
+
# need a body model to compute the joints after each step.
|
| 277 |
+
print(
|
| 278 |
+
"Using SMPL joints rather than regressed joints as input at each step for roll out and scheduled sampling..."
|
| 279 |
+
)
|
| 280 |
+
male_bm_path = os.path.join(SMPLH_PATH, "male/model.npz")
|
| 281 |
+
self.male_bm = BodyModel(
|
| 282 |
+
bm_path=male_bm_path, num_betas=16, batch_size=self.smpl_batch_size
|
| 283 |
+
)
|
| 284 |
+
female_bm_path = os.path.join(SMPLH_PATH, "female/model.npz")
|
| 285 |
+
self.female_bm = BodyModel(
|
| 286 |
+
bm_path=female_bm_path, num_betas=16, batch_size=self.smpl_batch_size
|
| 287 |
+
)
|
| 288 |
+
neutral_bm_path = os.path.join(SMPLH_PATH, "neutral/model.npz")
|
| 289 |
+
self.neutral_bm = BodyModel(
|
| 290 |
+
bm_path=neutral_bm_path, num_betas=16, batch_size=self.smpl_batch_size
|
| 291 |
+
)
|
| 292 |
+
self.bm_dict = {
|
| 293 |
+
"male": self.male_bm,
|
| 294 |
+
"female": self.female_bm,
|
| 295 |
+
"neutral": self.neutral_bm,
|
| 296 |
+
}
|
| 297 |
+
for p in self.male_bm.parameters():
|
| 298 |
+
p.requires_grad = False
|
| 299 |
+
for p in self.female_bm.parameters():
|
| 300 |
+
p.requires_grad = False
|
| 301 |
+
for p in self.neutral_bm.parameters():
|
| 302 |
+
p.requires_grad = False
|
| 303 |
+
self.ignore_keys = ["male_bm", "female_bm", "neutral_bm"]
|
| 304 |
+
|
| 305 |
+
    def prepare_input(
        self,
        data_in,
        device,
        data_out=None,
        return_input_dict=False,
        return_global_dict=False,
    ):
        """
        Concatenates input and output data as expected by the model.

        Also creates a dictionary of GT outputs for use in computing the loss. And optionally
        a dictionary of inputs.

        - data_in / data_out : dict-like batches keyed by self.data_names
          (assumes each entry reshapes to (B, T, steps_in/steps_out, -1) — TODO confirm
          against the dataset)
        - Returns x_past alone, or a tuple growing with x_t, gt_dict, input_dict,
          global_gt_dict depending on the flags and whether data_out is given.
        """

        #
        # input data
        #
        # flatten each named input stream to (B, T, steps_in, D_k)
        in_unnorm_data_list = []
        for k in self.data_names:
            cur_dat = data_in[k].to(device)
            B, T = cur_dat.size(0), cur_dat.size(1)
            cur_unnorm_dat = cur_dat.reshape((B, T, self.steps_in, -1))
            in_unnorm_data_list.append(cur_unnorm_dat)
        # concatenate streams along the feature axis
        x_past = torch.cat(in_unnorm_data_list, axis=3)

        input_dict = None
        if return_input_dict:
            input_dict = {k: v for k, v in zip(self.data_names, in_unnorm_data_list)}

            # auxiliary inputs are returned in the dict but never enter x_past
            if self.aux_in_data_names is not None:
                for k in self.aux_in_data_names:
                    cur_dat = data_in[k].to(device)
                    B, T = cur_dat.size(0), cur_dat.size(1)
                    cur_unnorm_dat = cur_dat.reshape((B, T, self.steps_in, -1))
                    input_dict[k] = cur_unnorm_dat

        #
        # output
        #
        if data_out is not None:
            # same flattening/concatenation for the GT output streams
            out_unnorm_data_list = []
            for k in self.data_names:
                cur_dat = data_out[k].to(device)
                B, T = cur_dat.size(0), cur_dat.size(1)
                cur_unnorm_dat = cur_dat.reshape((B, T, self.steps_out, -1))
                out_unnorm_data_list.append(cur_unnorm_dat)
            x_t = torch.cat(out_unnorm_data_list, axis=3)
            gt_dict = {k: v for k, v in zip(self.data_names, out_unnorm_data_list)}

            # auxiliary outputs (e.g. contacts) are supervised but not in x_t
            if self.aux_out_data_names is not None:
                for k in self.aux_out_data_names:
                    cur_dat = data_out[k].to(device)
                    B, T = cur_dat.size(0), cur_dat.size(1)
                    cur_unnorm_dat = cur_dat.reshape((B, T, self.steps_out, -1))
                    gt_dict[k] = cur_unnorm_dat

            return_list = [x_past, x_t, gt_dict]
            if return_input_dict:
                return_list.append(input_dict)

            #
            # global
            #
            # world-frame GT ("global_<name>" keys in data_out), expanded to match
            # the per-step gt_dict shapes
            if return_global_dict:
                global_gt_dict = dict()
                for k in self.data_names:
                    global_k = "global_" + k
                    cur_dat = data_out[global_k].to(device)
                    B, T = cur_dat.size(0), cur_dat.size(1)
                    # expand each to have steps_out since originally they are just B x T x ... x D
                    cur_dat = cur_dat.reshape((B, T, 1, -1)).expand_as(gt_dict[k])
                    global_gt_dict[k] = cur_dat

                if self.aux_out_data_names is not None:
                    for k in self.aux_out_data_names:
                        global_k = "global_" + k
                        cur_dat = data_out[global_k].to(device)
                        B, T = cur_dat.size(0), cur_dat.size(1)
                        # expand each to have steps_out since originally they are just B x T x ... x D
                        cur_dat = cur_dat.reshape((B, T, 1, -1)).expand_as(gt_dict[k])
                        global_gt_dict[k] = cur_dat

                return_list.append(global_gt_dict)

            return tuple(return_list)

        else:
            if return_input_dict:
                return x_past, input_dict
            else:
                return x_past
|
| 397 |
+
|
| 398 |
+
def split_output(self, decoder_out, convert_rots=True):
|
| 399 |
+
"""
|
| 400 |
+
Given the output of the decoder, splits into each state component.
|
| 401 |
+
Also transform rotation representation to matrices.
|
| 402 |
+
|
| 403 |
+
Input:
|
| 404 |
+
- decoder_out (B x steps_out x D)
|
| 405 |
+
|
| 406 |
+
Returns:
|
| 407 |
+
- output dict
|
| 408 |
+
"""
|
| 409 |
+
B = decoder_out.size(0)
|
| 410 |
+
decoder_out = decoder_out.reshape((B, self.steps_out, -1))
|
| 411 |
+
|
| 412 |
+
# collect outputs
|
| 413 |
+
name_list = self.data_names
|
| 414 |
+
if self.aux_out_data_names is not None:
|
| 415 |
+
name_list = name_list + self.aux_out_data_names
|
| 416 |
+
idx_list = (
|
| 417 |
+
self.delta_output_dim_list if self.output_delta else self.output_dim_list
|
| 418 |
+
)
|
| 419 |
+
out_dict = dict()
|
| 420 |
+
sidx = 0
|
| 421 |
+
for cur_name, cur_idx in zip(name_list, idx_list):
|
| 422 |
+
eidx = sidx + cur_idx
|
| 423 |
+
out_dict[cur_name] = decoder_out[:, :, sidx:eidx]
|
| 424 |
+
sidx = eidx
|
| 425 |
+
|
| 426 |
+
# transform rotations
|
| 427 |
+
if convert_rots and not self.output_delta: # output delta already gives rotmats
|
| 428 |
+
if "root_orient" in self.data_names:
|
| 429 |
+
out_dict["root_orient"] = convert_to_rotmat(
|
| 430 |
+
out_dict["root_orient"], rep=self.out_rot_rep
|
| 431 |
+
)
|
| 432 |
+
if "pose_body" in self.data_names:
|
| 433 |
+
out_dict["pose_body"] = convert_to_rotmat(
|
| 434 |
+
out_dict["pose_body"], rep=self.out_rot_rep
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
return out_dict
|
| 438 |
+
|
| 439 |
+
def forward(self, x_past, x_t):
|
| 440 |
+
"""
|
| 441 |
+
single step full forward pass. This uses the posterior for sampling, not the prior.
|
| 442 |
+
|
| 443 |
+
Input:
|
| 444 |
+
- x_past (B x steps_in x D)
|
| 445 |
+
- x_t (B x steps_out x D)
|
| 446 |
+
|
| 447 |
+
Returns dict of:
|
| 448 |
+
- x_pred (B x steps_out x D)
|
| 449 |
+
- posterior_distrib (Normal(mu, sigma))
|
| 450 |
+
- prior_distrib (Normal(mu, sigma))
|
| 451 |
+
"""
|
| 452 |
+
|
| 453 |
+
B, _, D = x_past.size()
|
| 454 |
+
past_in = x_past.reshape((B, -1))
|
| 455 |
+
t_in = x_t.reshape((B, -1))
|
| 456 |
+
|
| 457 |
+
x_pred_dict = self.single_step(past_in, t_in)
|
| 458 |
+
|
| 459 |
+
return x_pred_dict
|
| 460 |
+
|
| 461 |
+
def single_step(self, past_in, t_in):
|
| 462 |
+
"""
|
| 463 |
+
single step that computes both prior and posterior for training. Samples from posterior
|
| 464 |
+
"""
|
| 465 |
+
B = past_in.size(0)
|
| 466 |
+
# use past and future to encode latent transition
|
| 467 |
+
qm, qv = self.posterior(past_in, t_in)
|
| 468 |
+
|
| 469 |
+
# prior
|
| 470 |
+
pm, pv = None, None
|
| 471 |
+
if self.use_conditional_prior:
|
| 472 |
+
# predict prior based on past
|
| 473 |
+
pm, pv = self.prior(past_in)
|
| 474 |
+
else:
|
| 475 |
+
# use standard normal
|
| 476 |
+
pm, pv = torch.zeros_like(qm), torch.ones_like(qv)
|
| 477 |
+
|
| 478 |
+
# sample from posterior using reparam trick
|
| 479 |
+
z = self.rsample(qm, qv)
|
| 480 |
+
|
| 481 |
+
# decode to get next step
|
| 482 |
+
decoder_out = self.decode(z, past_in)
|
| 483 |
+
decoder_out = decoder_out.reshape(
|
| 484 |
+
(B, self.steps_out, -1)
|
| 485 |
+
) # B x steps_out x D_out
|
| 486 |
+
|
| 487 |
+
# split output predictions and transform out rotations to matrices
|
| 488 |
+
x_pred_dict = self.split_output(decoder_out)
|
| 489 |
+
|
| 490 |
+
x_pred_dict["posterior_distrib"] = (qm, qv)
|
| 491 |
+
x_pred_dict["prior_distrib"] = (pm, pv)
|
| 492 |
+
|
| 493 |
+
return x_pred_dict
|
| 494 |
+
|
| 495 |
+
def prior(self, past_in):
|
| 496 |
+
"""
|
| 497 |
+
Encodes the posterior distribution using the past and future states.
|
| 498 |
+
|
| 499 |
+
Input:
|
| 500 |
+
- past_in (B x steps_in*D)
|
| 501 |
+
"""
|
| 502 |
+
prior_out = self.prior_net(past_in)
|
| 503 |
+
mean = prior_out[:, : self.latent_size]
|
| 504 |
+
logvar = prior_out[:, self.latent_size :]
|
| 505 |
+
var = torch.exp(logvar)
|
| 506 |
+
return mean, var
|
| 507 |
+
|
| 508 |
+
def posterior(self, past_in, t_in):
|
| 509 |
+
"""
|
| 510 |
+
Encodes the posterior distribution using the past and future states.
|
| 511 |
+
|
| 512 |
+
Input:
|
| 513 |
+
- past_in (B x steps_in*D)
|
| 514 |
+
- t_in (B x steps_out*D)
|
| 515 |
+
"""
|
| 516 |
+
encoder_in = torch.cat([past_in, t_in], axis=1)
|
| 517 |
+
|
| 518 |
+
encoder_out = self.encoder(encoder_in)
|
| 519 |
+
mean = encoder_out[:, : self.latent_size]
|
| 520 |
+
logvar = encoder_out[:, self.latent_size :]
|
| 521 |
+
var = torch.exp(logvar)
|
| 522 |
+
|
| 523 |
+
return mean, var
|
| 524 |
+
|
| 525 |
+
def rsample(self, mu, var):
|
| 526 |
+
"""
|
| 527 |
+
Return gaussian sample of (mu, var) using reparameterization trick.
|
| 528 |
+
"""
|
| 529 |
+
eps = torch.randn_like(mu)
|
| 530 |
+
z = mu + eps * torch.sqrt(var)
|
| 531 |
+
return z
|
| 532 |
+
|
| 533 |
+
    def decode(self, z, past_in):
        """
        Decodes prediction from the latent transition and past states.

        Input:
        - z (B x latent_size)
        - past_in (B x steps_in*D)

        Returns:
        - decoder_out (B x steps_out*D)
        """
        B = z.size(0)
        # decoder is conditioned on the past states and the sampled latent
        decoder_in = torch.cat([past_in, z], axis=1)
        decoder_out = self.decoder(decoder_in).reshape((B, 1, -1))

        if self.output_delta:
            # network output is the residual, add to the input to get final output
            step_in = past_in.reshape((B, self.steps_in, -1))[
                :, -1:, :
            ]  # most recent input step

            final_out_list = []
            in_sidx = out_sidx = 0
            decode_out_dim_list = self.output_dim_list
            if self.pred_contacts:
                decode_out_dim_list = decode_out_dim_list[:-1]  # do contacts separately
            # walk the flat input and output vectors in lockstep, one data stream
            # at a time; slice widths can differ because input and output may use
            # different rotation representations
            for in_dim_idx, out_dim_idx, data_name in zip(
                self.input_dim_list, decode_out_dim_list, self.data_names
            ):
                in_eidx = in_sidx + in_dim_idx
                out_eidx = out_sidx + out_dim_idx

                # add residual to input (and transform as necessary for rotations)
                in_val = step_in[:, :, in_sidx:in_eidx]
                out_val = decoder_out[:, :, out_sidx:out_eidx]
                if data_name in ["root_orient", "pose_body"]:
                    # rotations compose multiplicatively, so convert both sides
                    # to rotation matrices first
                    if self.in_rot_rep != "mat":
                        in_val = convert_to_rotmat(in_val, rep=self.in_rot_rep)
                    out_val = convert_to_rotmat(out_val, rep=self.out_rot_rep)

                    in_val = in_val.reshape((B, 1, -1, 3, 3))
                    out_val = out_val.reshape((B, self.steps_out, -1, 3, 3))

                    rot_in = torch.matmul(out_val, in_val).reshape(
                        (B, self.steps_out, -1)
                    )  # rotate by predicted residual
                    final_out_list.append(rot_in)
                else:
                    # non-rotation streams are plain additive residuals
                    final_out_list.append(out_val + in_val)

                in_sidx = in_eidx
                out_sidx = out_eidx
            if self.pred_contacts:
                # contact logits pass through directly (no residual applied)
                final_out_list.append(decoder_out[:, :, out_sidx:])

            decoder_out = torch.cat(final_out_list, dim=2)

        # flatten to (B, steps_out*D_out)
        decoder_out = decoder_out.reshape((B, -1))

        return decoder_out
|
| 593 |
+
|
| 594 |
+
def scheduled_sampling(
    self,
    x_past,
    x_t,
    init_input_dict,
    p=0.5,
    gender=None,
    betas=None,
    need_global_out=True,
):
    """
    Given all inputs and ground truth outputs for all steps, roll out model predictions
    where at each step use the GT input with prob p, otherwise use own previous output.

    Input:
    - x_past (B x T x steps_in x D)
    - x_t (B x T x steps_out x D)
    - init_input_dict : dictionary of each initial state (B x steps_in x D), rotations should be matrices
    - p : probability of using the GT input at each step of the sequence
    - gender/betas only required if self.use_smpl_joint_inputs is true (used to decide the SMPL body model)

    Returns (pred_global_seq_out, pred_local_seq_out) when need_global_out is True,
    otherwise just pred_local_seq_out. Each is a dict of per-step predictions stacked
    along dim 1; distribution entries are (mean, var) tuples.
    """
    B, T, S, D = x_past.size()
    J = len(SMPL_JOINTS)
    cur_input_dict = init_input_dict  # this is the predicted input dict

    # initial input must be from GT since we don't have any predictions yet
    past_in = x_past[:, 0, :, :].reshape((B, -1))
    t_in = x_t[:, 0, :, :].reshape((B, -1))

    # running local->world transform, updated after every rolled-out step
    global_world2local_rot = (
        torch.eye(3).reshape((1, 1, 3, 3)).expand((B, 1, 3, 3)).to(x_past)
    )
    global_world2local_trans = torch.zeros((B, 1, 3)).to(x_past)
    trans2joint = torch.zeros((B, 1, 1, 3)).to(x_past)
    if self.need_trans2joint:
        trans2joint = -torch.cat(
            [cur_input_dict["joints"][:, -1, :2], torch.zeros((B, 1)).to(x_past)],
            axis=1,
        ).reshape(
            (B, 1, 1, 3)
        )  # same for whole sequence
    pred_local_seq = []
    pred_global_seq = []
    for t in range(T):
        # sample next step from model
        x_pred_dict = self.single_step(past_in, t_in)

        # save output
        pred_local_seq.append(x_pred_dict)

        # output is the actual regressed joints, but input to next step can use smpl joints
        x_pred_smpl_joints = None
        if self.use_smpl_joint_inputs and gender is not None and betas is not None:
            # this assumes the model is actually outputting everything we need to run SMPL
            # also assumes single output step
            smpl_trans = x_pred_dict["trans"][:, 0:1].reshape(
                (B, 3)
            )  # only want immediate next frame
            smpl_root_orient = rotation_matrix_to_angle_axis(
                x_pred_dict["root_orient"][:, 0:1].reshape((B, 3, 3))
            ).reshape((B, 3))
            smpl_betas = betas[:, 0, :]
            smpl_pose_body = rotation_matrix_to_angle_axis(
                x_pred_dict["pose_body"][:, 0:1].reshape((B * (J - 1), 3, 3))
            ).reshape((B, (J - 1) * 3))

            smpl_vals = [smpl_trans, smpl_root_orient, smpl_betas, smpl_pose_body]
            # batch may be a mix of genders, so need to carefully use the corresponding SMPL body model
            gender_names = ["male", "female", "neutral"]
            pred_joints = []
            prev_nbidx = 0
            # NOTE: builtin int, not np.int — the np.int alias was removed in NumPy 1.24
            cat_idx_map = np.ones((B), dtype=int) * -1
            for gender_name in gender_names:
                gender_idx = np.array(gender) == gender_name
                nbidx = np.sum(gender_idx)

                cat_idx_map[gender_idx] = np.arange(
                    prev_nbidx, prev_nbidx + nbidx, dtype=int
                )
                prev_nbidx += nbidx

                gender_smpl_vals = [val[gender_idx] for val in smpl_vals]

                # need to pad extra frames with zeros in case not as long as expected
                pad_size = self.smpl_batch_size - nbidx
                if pad_size == B:
                    # skip if no frames for this gender
                    continue
                pad_list = gender_smpl_vals
                if pad_size < 0:
                    raise Exception(
                        "SMPL model batch size not large enough to accomodate!"
                    )
                elif pad_size > 0:
                    pad_list = self.zero_pad_tensors(pad_list, pad_size)

                # reconstruct SMPL
                cur_pred_trans, cur_pred_orient, cur_betas, cur_pred_pose = pad_list
                bm = self.bm_dict[gender_name]
                pred_body = bm(
                    pose_body=cur_pred_pose,
                    betas=cur_betas,
                    root_orient=cur_pred_orient,
                    trans=cur_pred_trans,
                )
                if pad_size > 0:
                    pred_joints.append(pred_body.Jtr[:-pad_size])
                else:
                    pred_joints.append(pred_body.Jtr)

            # cat all genders and reorder to original batch ordering
            x_pred_smpl_joints = torch.cat(pred_joints, axis=0)[
                :, : len(SMPL_JOINTS), :
            ].reshape((B, 1, -1))
            x_pred_smpl_joints = x_pred_smpl_joints[cat_idx_map]

        # prepare predicted input to next step in case needed
        # update input dict with new frame
        del_keys = []
        for k in cur_input_dict.keys():
            if k in x_pred_dict:
                # drop oldest frame and add new prediction
                keep_frames = cur_input_dict[k][:, 1:, :]
                if (
                    k == "joints"
                    and self.use_smpl_joint_inputs
                    and x_pred_smpl_joints is not None
                ):
                    # feed SMPL joints rather than regressed joints to the next step
                    if self.detach_sched_samp:
                        cur_input_dict[k] = torch.cat(
                            [keep_frames, x_pred_smpl_joints.detach()], axis=1
                        )
                    else:
                        cur_input_dict[k] = torch.cat(
                            [keep_frames, x_pred_smpl_joints], axis=1
                        )
                else:
                    if self.detach_sched_samp:
                        cur_input_dict[k] = torch.cat(
                            [keep_frames, x_pred_dict[k][:, 0:1, :].detach()],
                            axis=1,
                        )
                    else:
                        cur_input_dict[k] = torch.cat(
                            [keep_frames, x_pred_dict[k][:, 0:1, :]], axis=1
                        )
            else:
                del_keys.append(k)
        for k in del_keys:
            del cur_input_dict[k]  # don't need it anymore

        # get world2aligned rot and translation
        if self.detach_sched_samp:
            root_orient_mat = (
                x_pred_dict["root_orient"][:, 0, :].reshape((B, 3, 3)).detach()
            )
            world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
            world2aligned_trans = torch.cat(
                [
                    -x_pred_dict["trans"][:, 0, :2].detach(),
                    torch.zeros((B, 1)).to(x_past),
                ],
                axis=1,
            )
        else:
            root_orient_mat = x_pred_dict["root_orient"][:, 0, :].reshape((B, 3, 3))
            world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
            world2aligned_trans = torch.cat(
                [-x_pred_dict["trans"][:, 0, :2], torch.zeros((B, 1)).to(x_past)],
                axis=1,
            )

        #
        # transform inputs to this local frame for next step
        #
        cur_input_dict = self.apply_world2local_trans(
            world2aligned_trans,
            world2aligned_rot,
            trans2joint,
            cur_input_dict,
            cur_input_dict,
            invert=False,
        )

        # convert rots to correct input format
        if self.in_rot_rep == "aa":
            if "root_orient" in self.data_names:
                cur_input_dict["root_orient"] = rotation_matrix_to_angle_axis(
                    cur_input_dict["root_orient"].reshape((B * S, 3, 3))
                ).reshape((B, S, 3))
            if "pose_body" in self.data_names:
                cur_input_dict["pose_body"] = rotation_matrix_to_angle_axis(
                    cur_input_dict["pose_body"].reshape((B * S * (J - 1), 3, 3))
                ).reshape((B, S, (J - 1) * 3))
        elif self.in_rot_rep == "6d":
            # 6d rep is the first two columns of the rotation matrix
            if "root_orient" in self.data_names:
                cur_input_dict["root_orient"] = cur_input_dict["root_orient"][
                    :, :, :6
                ]
            if "pose_body" in self.data_names:
                cur_input_dict["pose_body"] = (
                    cur_input_dict["pose_body"]
                    .reshape((B, S, J - 1, 9))[:, :, :, :6]
                    .reshape((B, S, (J - 1) * 6))
                )

        if need_global_out:
            #
            # compute current world output and update world2local transform
            #
            cur_world_dict = dict()
            cur_world_dict = self.apply_world2local_trans(
                global_world2local_trans,
                global_world2local_rot,
                trans2joint,
                x_pred_dict,
                cur_world_dict,
                invert=True,
            )

            if self.detach_sched_samp:
                global_world2local_trans = torch.cat(
                    [
                        -cur_world_dict["trans"][:, 0:1, :2].detach(),
                        torch.zeros((B, 1, 1)).to(x_past),
                    ],
                    axis=2,
                )
            else:
                global_world2local_trans = torch.cat(
                    [
                        -cur_world_dict["trans"][:, 0:1, :2],
                        torch.zeros((B, 1, 1)).to(x_past),
                    ],
                    axis=2,
                )

            global_world2local_rot = torch.matmul(
                global_world2local_rot, world2aligned_rot.reshape((B, 1, 3, 3))
            )

            pred_global_seq.append(cur_world_dict)

        if t + 1 < T:
            # choose whether next step will use GT or predicted inputs and prepare them
            if np.random.random_sample() < p:
                # use GT
                past_in = x_past[:, t + 1, :, :].reshape((B, -1))
            else:
                # cat all inputs together to form past_in
                in_data_list = []
                for k in self.data_names:
                    in_data_list.append(cur_input_dict[k])
                past_in = torch.cat(in_data_list, axis=2)
                past_in = past_in.reshape((B, -1))

            # GT output is the same no matter what
            t_in = x_t[:, t + 1, :, :].reshape((B, -1))

    if need_global_out:
        # aggregate global pred_seq
        pred_global_seq_out = dict()
        for k in pred_global_seq[0].keys():
            if k == "posterior_distrib" or k == "prior_distrib":
                m = torch.stack(
                    [pred_global_seq[i][k][0] for i in range(len(pred_global_seq))],
                    axis=1,
                )
                v = torch.stack(
                    [pred_global_seq[i][k][1] for i in range(len(pred_global_seq))],
                    axis=1,
                )
                pred_global_seq_out[k] = (m, v)
            else:
                pred_global_seq_out[k] = torch.stack(
                    [pred_global_seq[i][k] for i in range(len(pred_global_seq))],
                    axis=1,
                )

    # aggregate local pred_seq
    pred_local_seq_out = dict()
    for k in pred_local_seq[0].keys():
        if k == "posterior_distrib" or k == "prior_distrib":
            m = torch.stack(
                [pred_local_seq[i][k][0] for i in range(len(pred_local_seq))],
                axis=1,
            )
            v = torch.stack(
                [pred_local_seq[i][k][1] for i in range(len(pred_local_seq))],
                axis=1,
            )
            pred_local_seq_out[k] = (m, v)
        else:
            pred_local_seq_out[k] = torch.stack(
                [pred_local_seq[i][k] for i in range(len(pred_local_seq))], axis=1
            )

    if need_global_out:
        return pred_global_seq_out, pred_local_seq_out
    else:
        return pred_local_seq_out
def apply_world2local_trans(
    self,
    world2local_trans,
    world2local_rot,
    trans2joint,
    input_dict,
    output_dict,
    invert=False,
):
    """
    Applies the given world2local transformation to the data in input_dict and stores the result in output_dict.

    If invert is true, applies local2world.

    - world2local_trans : B x 3 or B x 1 x 3
    - world2local_rot : B x 3 x 3 or B x 1 x 3 x 3
    - trans2joint : B x 1 x 1 x 3

    input_dict and output_dict may be the same dict (in-place update). Keys not
    listed in WORLD2ALIGN_NAME_CACHE are frame-independent and copied through
    unchanged. Returns output_dict.
    """
    B = world2local_trans.size(0)
    world2local_rot = world2local_rot.reshape((B, 1, 3, 3))
    world2local_trans = world2local_trans.reshape((B, 1, 3))
    trans2joint = trans2joint.reshape((B, 1, 1, 3))
    if invert:
        # rotation matrices are orthonormal, so transpose == inverse
        local2world_rot = world2local_rot.transpose(3, 2)
    for k, v in input_dict.items():
        # apply differently depending on which data value it is
        if k not in WORLD2ALIGN_NAME_CACHE:
            # frame of reference is irrelevant, just copy to output
            output_dict[k] = input_dict[k]
            continue

        S = input_dict[k].size(1)
        if k in ["root_orient"]:
            # rot: B x S x 3 x 3 sized rotation matrix input
            input_mat = input_dict[k].reshape(
                (B, S, 3, 3)
            )  # make sure not B x S x 9
            if invert:
                output_dict[k] = torch.matmul(local2world_rot, input_mat).reshape(
                    (B, S, 9)
                )
            else:
                output_dict[k] = torch.matmul(world2local_rot, input_mat).reshape(
                    (B, S, 9)
                )
        elif k in ["trans"]:
            # trans + rot : B x S x 3
            input_trans = input_dict[k]
            if invert:
                output_trans = torch.matmul(
                    local2world_rot, input_trans.reshape((B, S, 3, 1))
                )[:, :, :, 0]
                output_trans = output_trans - world2local_trans
                output_dict[k] = output_trans
            else:
                input_trans = input_trans + world2local_trans
                output_dict[k] = torch.matmul(
                    world2local_rot, input_trans.reshape((B, S, 3, 1))
                )[:, :, :, 0]
        elif k in ["joints", "verts"]:
            # trans + joint + rot : B x S x J x 3
            J = input_dict[k].size(2) // 3
            input_pts = input_dict[k].reshape((B, S, J, 3))
            if invert:
                input_pts = input_pts + trans2joint
                output_pts = torch.matmul(
                    local2world_rot.reshape((B, 1, 1, 3, 3)),
                    input_pts.reshape((B, S, J, 3, 1)),
                )[:, :, :, :, 0]
                output_pts = (
                    output_pts
                    - trans2joint
                    - world2local_trans.reshape((B, 1, 1, 3))
                )
                output_dict[k] = output_pts.reshape((B, S, J * 3))
            else:
                input_pts = (
                    input_pts
                    + world2local_trans.reshape((B, 1, 1, 3))
                    + trans2joint
                )
                output_pts = torch.matmul(
                    world2local_rot.reshape((B, 1, 1, 3, 3)),
                    input_pts.reshape((B, S, J, 3, 1)),
                )[:, :, :, :, 0]
                output_pts = output_pts - trans2joint
                output_dict[k] = output_pts.reshape((B, S, J * 3))
        elif k in ["joints_vel", "verts_vel"]:
            # rot : B x S x J x 3 (velocities only need rotation, no translation)
            J = input_dict[k].size(2) // 3
            input_pts = input_dict[k].reshape((B, S, J, 3, 1))
            if invert:
                # fix: renamed typo'd local `outuput_pts` -> `output_pts`
                output_pts = torch.matmul(
                    local2world_rot.reshape((B, 1, 1, 3, 3)), input_pts
                )[:, :, :, :, 0]
                output_dict[k] = output_pts.reshape((B, S, J * 3))
            else:
                output_pts = torch.matmul(
                    world2local_rot.reshape((B, 1, 1, 3, 3)), input_pts
                )[:, :, :, :, 0]
                output_dict[k] = output_pts.reshape((B, S, J * 3))
        elif k in ["trans_vel", "root_orient_vel"]:
            # rot : B x S x 3
            input_pts = input_dict[k].reshape((B, S, 3, 1))
            if invert:
                output_dict[k] = torch.matmul(local2world_rot, input_pts)[
                    :, :, :, 0
                ]
            else:
                output_dict[k] = torch.matmul(world2local_rot, input_pts)[
                    :, :, :, 0
                ]
        else:
            print(
                "Received an unexpected key when transforming world2local: %s!"
                % (k)
            )
            exit()

    return output_dict
def zero_pad_tensors(self, pad_list, pad_size):
    """
    Assumes tensors in pad_list are B x D

    Returns a new list where each tensor has `pad_size` zero rows appended
    along dim 0 (same dtype/device as the source tensor).
    """
    return [
        torch.cat(
            [tensor, torch.zeros((pad_size, tensor.size(1))).to(tensor)], dim=0
        )
        for tensor in pad_list
    ]
def roll_out(
    self,
    x_past,
    init_input_dict,
    num_steps,
    use_mean=False,
    z_seq=None,
    return_prior=False,
    gender=None,
    betas=None,
    return_z=False,
    canonicalize_input=False,
    uncanonicalize_output=False,
):
    """
    Given input for first step, roll out using own output the entire time by sampling from the prior.
    Returns the global trajectory.

    Input:
    - x_past (B x steps_in x D_in)
    - init_input_dict : dictionary of each initial state (B x steps_in x D), rotations should be matrices
                        (assumes initial state is already in its local coordinate system (translation at [0,0,z] and aligned))
    - num_steps : the number of timesteps to roll out
    - use_mean : if True, uses the mean of latent distribution instead of sampling
    - z_seq : (B x steps_out x D) if given, uses as the latent input to decoder at each step rather than sampling
    - return_prior : if True, also returns the output of the conditional prior at each step
    - gender : list of e.g. ['male', 'female', etc..] of length B
    - betas : B x steps_in x D
    - return_z : returns the sampled z sequence in addition to the output
    - canonicalize_input : if true, the input initial state is assumed to not be in the local aligned coordinate system. It will be transformed before using.
    - uncanonicalize_output : if true and canonicalize_input=True, will transform output back into the input frame rather than return in canonical frame.

    Returns:
    - x_pred - dict of (B x num_steps x D_out) for each value. Rotations are all matrices.
    """
    J = len(SMPL_JOINTS)
    cur_input_dict = init_input_dict

    # need to transform init input to local frame
    world2aligned_rot = world2aligned_trans = None
    if canonicalize_input:
        B, _, _ = cur_input_dict[list(cur_input_dict.keys())[0]].size()
        # must transform initial input into the local frame
        # get world2aligned rot and translation
        root_orient_mat = cur_input_dict["root_orient"]
        # NOTE: pose_body_mat is converted but never consumed below; kept for parity
        pose_body_mat = cur_input_dict["pose_body"]
        if "root_orient" in self.data_names and self.in_rot_rep != "mat":
            root_orient_mat = convert_to_rotmat(
                root_orient_mat, rep=self.in_rot_rep
            )
        if "pose_body" in self.data_names and self.in_rot_rep != "mat":
            pose_body_mat = convert_to_rotmat(pose_body_mat, rep=self.in_rot_rep)

        root_orient_mat = root_orient_mat[:, -1].reshape((B, 3, 3))
        world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
        world2aligned_trans = torch.cat(
            [
                -cur_input_dict["trans"][:, -1, :2],
                torch.zeros((B, 1)).to(root_orient_mat),
            ],
            axis=1,
        )

        # compute trans2joint
        # fix: default to zeros so trans2joint is always bound below
        # (previously a NameError when need_trans2joint was False)
        trans2joint = torch.zeros((B, 1, 1, 3)).to(root_orient_mat)
        if self.need_trans2joint:
            trans2joint = -(
                cur_input_dict["joints"][:, -1, :2] + world2aligned_trans[:, :2]
            )
            trans2joint = torch.cat(
                [trans2joint, torch.zeros((B, 1)).to(trans2joint)], axis=1
            ).reshape((B, 1, 1, 3))

        # transform to local frame
        cur_input_dict = self.apply_world2local_trans(
            world2aligned_trans,
            world2aligned_rot,
            trans2joint,
            cur_input_dict,
            cur_input_dict,
            invert=False,
        )

    # check to make sure we have enough input steps, if not, pad
    pad_x_past = x_past is not None and x_past.size(1) < self.steps_in
    pad_in_dict = (
        cur_input_dict[list(cur_input_dict.keys())[0]].size(1) < self.steps_in
    )
    if pad_x_past:
        num_pad_steps = self.steps_in - x_past.size(1)
        cur_padding = torch.zeros(
            (x_past.size(0), num_pad_steps, x_past.size(2))
        ).to(
            x_past
        )  # assuming all data is B x T x D
        x_past = torch.cat([cur_padding, x_past], axis=1)
    if pad_in_dict:
        for k in cur_input_dict.keys():
            cur_in_dat = cur_input_dict[k]
            num_pad_steps = self.steps_in - cur_in_dat.size(1)
            cur_padding = torch.zeros(
                (cur_in_dat.size(0), num_pad_steps, cur_in_dat.size(2))
            ).to(
                cur_in_dat
            )  # assuming all data is B x T x D
            padded_in_dat = torch.cat([cur_padding, cur_in_dat], axis=1)
            cur_input_dict[k] = padded_in_dat

    if x_past is None or canonicalize_input:
        x_past = [cur_input_dict[k] for k in self.data_names]
        x_past = torch.cat(x_past, axis=2)
    B, S, D = x_past.size()
    past_in = x_past.reshape((B, -1))

    # running local->world transform, updated after every step
    global_world2local_rot = (
        torch.eye(3).reshape((1, 1, 3, 3)).expand((B, 1, 3, 3)).to(x_past)
    )
    global_world2local_trans = torch.zeros((B, 1, 3)).to(x_past)
    if canonicalize_input and uncanonicalize_output:
        # seed with the input frame so outputs map back into it
        global_world2local_rot = world2aligned_rot.unsqueeze(1)
        global_world2local_trans = world2aligned_trans.unsqueeze(1)
    trans2joint = torch.zeros((B, 1, 1, 3)).to(x_past)
    if self.need_trans2joint:
        trans2joint = -torch.cat(
            [cur_input_dict["joints"][:, -1, :2], torch.zeros((B, 1)).to(x_past)],
            axis=1,
        ).reshape(
            (B, 1, 1, 3)
        )  # same for whole sequence
    pred_local_seq = []
    pred_global_seq = []
    prior_seq = []
    z_out_seq = []
    for t in range(num_steps):
        x_pred_dict = None
        # sample next step
        z_in = None
        if z_seq is not None:
            z_in = z_seq[:, t]
        sample_out = self.sample_step(
            past_in,
            use_mean=use_mean,
            z=z_in,
            return_prior=return_prior,
            return_z=return_z,
        )
        if return_prior:
            prior_out = sample_out["prior"]
            prior_seq.append(prior_out)
        if return_z:
            z_out = sample_out["z"]
            z_out_seq.append(z_out)
        decoder_out = sample_out["decoder_out"]

        # split output predictions and transform out rotations to matrices
        x_pred_dict = self.split_output(decoder_out, convert_rots=True)
        if self.steps_out > 1:
            for k in x_pred_dict.keys():
                # only want immediate next frame prediction
                x_pred_dict[k] = x_pred_dict[k][:, 0:1, :]

        pred_local_seq.append(x_pred_dict)

        # output is the actual regressed joints, but input to next step can use smpl joints
        x_pred_smpl_joints = None
        if self.use_smpl_joint_inputs and gender is not None and betas is not None:
            # this assumes the model is actually outputting everything we need to run SMPL
            # also assumes single output step
            smpl_trans = x_pred_dict["trans"].reshape((B, 3))
            smpl_root_orient = rotation_matrix_to_angle_axis(
                x_pred_dict["root_orient"].reshape((B, 3, 3))
            ).reshape((B, 3))
            smpl_betas = betas[:, 0, :]
            smpl_pose_body = rotation_matrix_to_angle_axis(
                x_pred_dict["pose_body"].reshape((B * (J - 1), 3, 3))
            ).reshape((B, (J - 1) * 3))

            smpl_vals = [smpl_trans, smpl_root_orient, smpl_betas, smpl_pose_body]
            # each batch index may be a different gender
            gender_names = ["male", "female", "neutral"]
            pred_joints = []
            prev_nbidx = 0
            # NOTE: builtin int, not np.int — the np.int alias was removed in NumPy 1.24
            cat_idx_map = np.ones((B), dtype=int) * -1
            for gender_name in gender_names:
                gender_idx = np.array(gender) == gender_name
                nbidx = np.sum(gender_idx)
                cat_idx_map[gender_idx] = np.arange(
                    prev_nbidx, prev_nbidx + nbidx, dtype=int
                )
                prev_nbidx += nbidx

                gender_smpl_vals = [val[gender_idx] for val in smpl_vals]

                # need to pad extra frames with zeros in case not as long as expected
                pad_size = self.smpl_batch_size - nbidx
                if pad_size == B:
                    # skip if no frames for this gender
                    continue
                pad_list = gender_smpl_vals
                if pad_size < 0:
                    raise Exception(
                        "SMPL model batch size not large enough to accomodate!"
                    )
                elif pad_size > 0:
                    pad_list = self.zero_pad_tensors(pad_list, pad_size)

                # reconstruct SMPL
                cur_pred_trans, cur_pred_orient, cur_betas, cur_pred_pose = pad_list
                bm = self.bm_dict[gender_name]
                pred_body = bm(
                    pose_body=cur_pred_pose,
                    betas=cur_betas,
                    root_orient=cur_pred_orient,
                    trans=cur_pred_trans,
                )
                if pad_size > 0:
                    pred_joints.append(pred_body.Jtr[:-pad_size])
                else:
                    pred_joints.append(pred_body.Jtr)

            # cat all genders and reorder to original batch ordering
            x_pred_smpl_joints = torch.cat(pred_joints, axis=0)[
                :, : len(SMPL_JOINTS), :
            ].reshape((B, 1, -1))
            x_pred_smpl_joints = x_pred_smpl_joints[cat_idx_map]

        # prepare input to next step
        # update input dict with new frame
        del_keys = []
        for k in cur_input_dict.keys():
            if k in x_pred_dict:
                # drop oldest frame and add new prediction
                keep_frames = cur_input_dict[k][:, 1:, :]

                if (
                    k == "joints"
                    and self.use_smpl_joint_inputs
                    and x_pred_smpl_joints is not None
                ):
                    cur_input_dict[k] = torch.cat(
                        [keep_frames, x_pred_smpl_joints], axis=1
                    )
                else:
                    cur_input_dict[k] = torch.cat(
                        [keep_frames, x_pred_dict[k]], axis=1
                    )
            else:
                del_keys.append(k)
        for k in del_keys:
            del cur_input_dict[k]

        # get world2aligned rot and translation
        root_orient_mat = x_pred_dict["root_orient"][:, 0, :].reshape((B, 3, 3))
        world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
        world2aligned_trans = torch.cat(
            [-x_pred_dict["trans"][:, 0, :2], torch.zeros((B, 1)).to(x_past)],
            axis=1,
        )

        #
        # transform inputs to this local frame (body pose is not affected) for next step
        #
        cur_input_dict = self.apply_world2local_trans(
            world2aligned_trans,
            world2aligned_rot,
            trans2joint,
            cur_input_dict,
            cur_input_dict,
            invert=False,
        )

        # convert rots to correct input format
        if self.in_rot_rep == "aa":
            if "root_orient" in self.data_names:
                cur_input_dict["root_orient"] = rotation_matrix_to_angle_axis(
                    cur_input_dict["root_orient"].reshape((B * S, 3, 3))
                ).reshape((B, S, 3))
            if "pose_body" in self.data_names:
                cur_input_dict["pose_body"] = rotation_matrix_to_angle_axis(
                    cur_input_dict["pose_body"].reshape((B * S * (J - 1), 3, 3))
                ).reshape((B, S, (J - 1) * 3))
        elif self.in_rot_rep == "6d":
            # 6d rep is the first two columns of the rotation matrix
            if "root_orient" in self.data_names:
                cur_input_dict["root_orient"] = cur_input_dict["root_orient"][
                    :, :, :6
                ]
            if "pose_body" in self.data_names:
                cur_input_dict["pose_body"] = (
                    cur_input_dict["pose_body"]
                    .reshape((B, S, J - 1, 9))[:, :, :, :6]
                    .reshape((B, S, (J - 1) * 6))
                )

        #
        # compute current world output and update world2local transform
        #
        cur_world_dict = dict()
        cur_world_dict = self.apply_world2local_trans(
            global_world2local_trans,
            global_world2local_rot,
            trans2joint,
            x_pred_dict,
            cur_world_dict,
            invert=True,
        )
        #
        # update world2local transform
        #
        global_world2local_trans = torch.cat(
            [
                -cur_world_dict["trans"][:, 0:1, :2],
                torch.zeros((B, 1, 1)).to(x_past),
            ],
            axis=2,
        )
        global_world2local_rot = torch.matmul(
            global_world2local_rot, world2aligned_rot.reshape((B, 1, 3, 3))
        )

        pred_global_seq.append(cur_world_dict)

        # cat all inputs together to form past_in
        in_data_list = []
        for k in self.data_names:
            in_data_list.append(cur_input_dict[k])
        past_in = torch.cat(in_data_list, axis=2)
        past_in = past_in.reshape((B, -1))

    # aggregate global pred_seq
    pred_seq_out = dict()
    for k in pred_global_seq[0].keys():
        pred_seq_out[k] = torch.cat(
            [pred_global_seq[i][k] for i in range(len(pred_global_seq))], axis=1
        )

    if return_z:
        z_out_seq = torch.stack(z_out_seq, dim=1)
        pred_seq_out["z"] = z_out_seq

    if return_prior:
        pm = torch.stack([prior_seq[i][0] for i in range(len(prior_seq))], axis=1)
        pv = torch.stack([prior_seq[i][1] for i in range(len(prior_seq))], axis=1)
        return pred_seq_out, (pm, pv)
    else:
        return pred_seq_out
def sample_step(
    self,
    past_in,
    t_in=None,
    use_mean=False,
    z=None,
    return_prior=False,
    return_z=False,
):
    """
    Sample the next future state given the past.

    The transition latent comes from, in order of precedence: the given z,
    the posterior q(z | past, future) when t_in is provided, or the prior
    (conditional on the past, or a standard normal).

    Input:
    - past_in (B x D_past) : flattened past window
    - t_in (B x D_out) : optional flattened future window (enables the posterior)
    - use_mean : take the distribution mean instead of sampling
    - z (B x D) : optional latent, used directly instead of sampling
    - return_prior / return_z : include the distribution / latent in the output

    Returns a dict with:
    - decoder_out (B x steps_out x D) : decoder output for the immediate next step
    - "prior" : (mu, var), only when return_prior
    - "z" : the latent used, only when return_z
    """
    batch = past_in.size(0)

    if t_in is not None:
        # use past and future to encode the latent transition
        dist_m, dist_v = self.posterior(past_in, t_in)
    elif self.use_conditional_prior:
        # prior predicted from the past
        dist_m, dist_v = self.prior(past_in)
    else:
        # unconditional standard normal prior
        dist_m = torch.zeros((batch, self.latent_size)).to(past_in)
        dist_v = torch.ones((batch, self.latent_size)).to(past_in)

    if z is None:
        # either reparameterized sample or the mean of the distribution
        z = dist_m if use_mean else self.rsample(dist_m, dist_v)

    # decode latent + past into the next step(s): B x steps_out x D_out
    next_steps = self.decode(z, past_in).reshape((batch, self.steps_out, -1))

    out_dict = {"decoder_out": next_steps}
    if return_prior:
        out_dict["prior"] = (dist_m, dist_v)
    if return_z:
        out_dict["z"] = z
    return out_dict
|
| 1431 |
+
|
| 1432 |
+
def infer_global_seq(self, global_seq, full_forward_pass=False):
    """
    Given a sequence of global states, formats it (transform each step into the
    local frame and make B x steps_in x D) and runs inference (compute
    prior/posterior of z for the sequence).

    If full_forward_pass is True, does an entire forward pass at each step
    rather than just inference.

    Rotations should be in in_rot_rep format.
    """
    # used to compute output zero padding
    needed_future_steps = (self.steps_out - 1) * self.out_step_size

    prior_m_seq = []
    prior_v_seq = []
    post_m_seq = []
    post_v_seq = []
    pred_dict_seq = []
    # assumes every entry of global_seq is B x T x D — TODO confirm against caller
    B, T, _ = global_seq[list(global_seq.keys())[0]].size()
    # NOTE(review): J is currently unused in this method
    J = len(SMPL_JOINTS)
    trans2joint = None
    for t in range(T - 1):
        # get world2aligned rot and translation
        world2aligned_rot = world2aligned_trans = None

        # root orientation at step t as a rotation matrix (rotation rep is 9D here)
        root_orient_mat = global_seq["root_orient"][:, t, :].reshape((B, 3, 3))
        world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
        # translation that cancels the xy root translation (z left untouched)
        world2aligned_trans = torch.cat(
            [
                -global_seq["trans"][:, t, :2],
                torch.zeros((B, 1)).to(root_orient_mat),
            ],
            axis=1,
        )

        # compute trans2joint at first step only; it is reused for all later steps
        if t == 0 and self.need_trans2joint:
            trans2joint = -(
                global_seq["joints"][:, t, :2] + world2aligned_trans[:, :2]
            )  # we cannot make the assumption that the first frame is already canonical
            trans2joint = torch.cat(
                [trans2joint, torch.zeros((B, 1)).to(trans2joint)], axis=1
            ).reshape((B, 1, 1, 3))

        # get current window: steps_in past frames (front zero-padded) followed
        # by steps_out future frames (back zero-padded), per data stream k
        cur_data_dict = dict()
        for k in global_seq.keys():
            # get in steps
            in_sidx = max(0, t - self.steps_in + 1)
            cur_in_seq = global_seq[k][:, in_sidx : (t + 1), :]
            if cur_in_seq.size(1) < self.steps_in:
                # must zero pad front
                num_pad_steps = self.steps_in - cur_in_seq.size(1)
                cur_padding = torch.zeros(
                    (cur_in_seq.size(0), num_pad_steps, cur_in_seq.size(2))
                ).to(
                    cur_in_seq
                )  # assuming all data is B x T x D
                cur_in_seq = torch.cat([cur_padding, cur_in_seq], axis=1)

            # get out steps, strided by out_step_size
            cur_out_seq = global_seq[k][
                :, (t + 1) : (t + 2 + needed_future_steps) : self.out_step_size
            ]
            if cur_out_seq.size(1) < self.steps_out:
                # zero pad at the end of the sequence
                num_pad_steps = self.steps_out - cur_out_seq.size(1)
                cur_padding = torch.zeros_like(cur_out_seq[:, 0])
                cur_padding = torch.stack([cur_padding] * num_pad_steps, axis=1)
                cur_out_seq = torch.cat([cur_out_seq, cur_padding], axis=1)
            cur_data_dict[k] = torch.cat([cur_in_seq, cur_out_seq], axis=1)

        # transform to local frame (in place: input and output dict are the same)
        cur_data_dict = self.apply_world2local_trans(
            world2aligned_trans,
            world2aligned_rot,
            trans2joint,
            cur_data_dict,
            cur_data_dict,
            invert=False,
        )

        # create x_past and x_t
        # cat all inputs together to form past_in
        in_data_list = []
        for k in self.data_names:
            in_data_list.append(cur_data_dict[k][:, : self.steps_in, :])
        x_past = torch.cat(in_data_list, axis=2)
        # cat all outputs together to form x_t
        out_data_list = []
        for k in self.data_names:
            out_data_list.append(cur_data_dict[k][:, self.steps_in :, :])
        x_t = torch.cat(out_data_list, axis=2)

        if full_forward_pass:
            # full model forward pass (encode + decode) for this step
            x_pred_dict = self(x_past, x_t)
            pred_dict_seq.append(x_pred_dict)
        else:
            # perform inference only (prior and posterior of z)
            prior_z, posterior_z = self.infer(x_past, x_t)
            # save z distribution parameters
            prior_m_seq.append(prior_z[0])
            prior_v_seq.append(prior_z[1])
            post_m_seq.append(posterior_z[0])
            post_v_seq.append(posterior_z[1])

    if full_forward_pass:
        # aggregate the per-step prediction dicts into B x (T-1) x ... tensors
        pred_seq_out = dict()
        for k in pred_dict_seq[0].keys():
            if k == "posterior_distrib" or k == "prior_distrib":
                # distribution entries are (mean, var) tuples; stack each part
                m = torch.stack(
                    [pred_dict_seq[i][k][0] for i in range(len(pred_dict_seq))],
                    axis=1,
                )
                v = torch.stack(
                    [pred_dict_seq[i][k][1] for i in range(len(pred_dict_seq))],
                    axis=1,
                )
                pred_seq_out[k] = (m, v)
            else:
                pred_seq_out[k] = torch.stack(
                    [pred_dict_seq[i][k] for i in range(len(pred_dict_seq))], axis=1
                )

        return pred_seq_out
    else:
        # stack per-step distribution parameters along the time dimension
        prior_m_seq = torch.stack(prior_m_seq, axis=1)
        prior_v_seq = torch.stack(prior_v_seq, axis=1)
        post_m_seq = torch.stack(post_m_seq, axis=1)
        post_v_seq = torch.stack(post_v_seq, axis=1)

        return (prior_m_seq, prior_v_seq), (post_m_seq, post_v_seq)
|
| 1565 |
+
|
| 1566 |
+
def infer(self, x_past, x_t):
    """
    Inference (compute prior and posterior distribution of z) for a batch of
    single steps.
    NOTE: must do processing before passing in to ensure the correct format
    that this function expects.

    Input:
    - x_past (B x steps_in x D)
    - x_t (B x steps_out x D)

    Returns:
    - prior_distrib (mu, var)
    - posterior_distrib (mu, var)
    """
    batch = x_past.size(0)
    # flatten both windows to B x (steps * D) and run the single-step inference
    return self.infer_step(
        x_past.reshape((batch, -1)), x_t.reshape((batch, -1))
    )
|
| 1587 |
+
|
| 1588 |
+
def infer_step(self, past_in, t_in):
    """
    Single step that computes both prior and posterior for training.

    Input:
    - past_in (B x D_past) : flattened past window
    - t_in (B x D_out) : flattened future window

    Returns:
    - (pm, pv) : prior (mean, var)
    - (qm, qv) : posterior (mean, var)
    """
    # use past and future to encode the latent transition (posterior)
    qm, qv = self.posterior(past_in, t_in)

    # prior: conditioned on the past, or a standard normal
    if self.use_conditional_prior:
        pm, pv = self.prior(past_in)
    else:
        pm, pv = torch.zeros_like(qm), torch.ones_like(qv)

    return (pm, pv), (qm, qv)
|
| 1606 |
+
|
| 1607 |
+
|
| 1608 |
+
class MLP(nn.Module):
    """
    Fully-connected network with optional GroupNorm and optional input
    skip connections.

    For layers=[D_in, H1, ..., D_out] the module sequence is:
    Linear(D_in, H1) -> [GroupNorm -> nonlinearity -> Linear] per later layer.
    """

    def __init__(
        self,
        layers=(3, 128, 128, 3),
        nonlinearity=nn.ReLU,
        use_gn=True,
        skip_input_idx=None,
    ):
        """
        :param layers: sequence of sizes [in, hidden..., out]
            (default changed from a mutable list to an equivalent tuple to
            avoid the shared-mutable-default pitfall; callers may still pass a list)
        :param nonlinearity: activation module class, instantiated per layer
        :param use_gn: insert GroupNorm(16, C) before each activation
            (NOTE: channel counts must be divisible by 16)
        :param skip_input_idx: if not None, the input features after index
            skip_input_idx are skip-connected (concatenated) to the input of
            every later Linear layer of the MLP.
        """
        super(MLP, self).__init__()

        in_size = layers[0]
        out_channels = list(layers[1:])

        # input layer (kept in a separate local so the `layers` parameter is
        # no longer shadowed, as it was in the original implementation)
        net_layers = [nn.Linear(in_size, out_channels[0])]
        # number of extra features concatenated into each later Linear
        skip_size = 0 if skip_input_idx is None else (in_size - skip_input_idx)
        # now the rest
        for layer_idx in range(1, len(out_channels)):
            fc_layer = nn.Linear(
                out_channels[layer_idx - 1] + skip_size, out_channels[layer_idx]
            )
            if use_gn:
                net_layers.append(nn.GroupNorm(16, out_channels[layer_idx - 1]))
            net_layers.extend([nonlinearity(), fc_layer])
        self.net = nn.ModuleList(net_layers)
        self.skip_input_idx = skip_input_idx

    def forward(self, x):
        """
        x : B x D. Runs the layer list in order; when skip connections are
        enabled, the skip features are concatenated before every Linear
        except the first.
        """
        skip_in = None
        if self.skip_input_idx is not None:
            skip_in = x[:, self.skip_input_idx :]
        for i, layer in enumerate(self.net):
            if (
                self.skip_input_idx is not None
                and i > 0
                and isinstance(layer, nn.Linear)
            ):
                x = torch.cat([x, skip_in], dim=1)
            x = layer(x)
        return x
|
slahmr/slahmr/humor/transforms.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Taken from https://github.com/davrempe/humor
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
from body_model.utils import SMPL_JOINTS
|
| 12 |
+
|
| 13 |
+
#
|
| 14 |
+
# For computing local body frame
|
| 15 |
+
#
|
| 16 |
+
|
| 17 |
+
# Device chosen once at import time for the constant axis tensors below.
# NOTE(review): pinned to cuda:0 when available; callers on other devices
# appear to rely on .to()/.expand_as to move these — confirm.
GLOB_DEVICE = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
# Mask that keeps the ground-plane (x, y) components and zeroes z.
XY_AXIS_GLOB = torch.Tensor([[1.0, 1.0, 0.0]]).to(device=GLOB_DEVICE)
# World x axis, the target direction for heading alignment.
X_AXIS_GLOB = torch.Tensor([[1.0, 0.0, 0.0]]).to(device=GLOB_DEVICE)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compute_aligned_from_right(body_right):
    """
    Compute the rotation about the world up axis that aligns the given
    body-right direction (B x 3) with the world x axis.

    Returns:
    - world2aligned_mat (B x 3 x 3) : alignment rotation matrices
    - world2aligned_aa (B x 3) : the same rotations as axis-angle
    """
    # cosine of the angle between the xy-projection of body_right and world x
    cos_to_x = body_right[:, 0:1] / (
        torch.norm(body_right[:, :2], dim=1, keepdim=True) + 1e-6
    )
    cos_to_x = torch.clamp(cos_to_x, min=-1.0, max=1.0)  # avoid acos error

    # project to world x axis, and compute the angle
    align_angle = torch.acos(cos_to_x)
    # drop the vertical component before taking the cross product
    flat_right = body_right * XY_AXIS_GLOB
    align_axis = torch.linalg.cross(flat_right, X_AXIS_GLOB.expand_as(flat_right))

    # normalized axis scaled by the angle gives the axis-angle rotation
    world2aligned_aa = (
        align_axis / (torch.norm(align_axis, dim=1, keepdim=True) + 1e-6)
    ) * align_angle
    return batch_rodrigues(world2aligned_aa), world2aligned_aa
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def compute_world2aligned_mat(rot_pos):
    """
    Batch of world rotation matrices: B x 3 x 3.
    Returns rotation matrices that align the inputs to the forward
    direction: B x 3 x 3. Torch version.
    """
    # in body coordinates the body x axis points left, so negate it
    right_dir = -rot_pos[:, :, 0]
    aligned_mat, _ = compute_aligned_from_right(right_dir)
    return aligned_mat
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def compute_world2aligned_joints_mat(joints):
    """
    Compute the world-to-canonical-frame rotation (about the up axis)
    from the given batch of joints (B x J x 3), using the hip-to-hip
    direction as the body-right direction.
    """
    left_hip = SMPL_JOINTS["leftUpLeg"]
    right_hip = SMPL_JOINTS["rightUpLeg"]

    # unit vector from the left hip toward the right hip
    right_dir = joints[:, right_hip] - joints[:, left_hip]
    right_dir = right_dir / torch.norm(right_dir, dim=1, keepdim=True)

    aligned_mat, _ = compute_aligned_from_right(right_dir)
    return aligned_mat
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def convert_to_rotmat(pred_rot, rep="aa"):
    """
    Converts a rotation representation to rotation matrices.

    pred_rot : B x T x N, where N depends on rep
    rep : one of "aa" (axis-angle, 3), "6d" (6), "9d" (9)
    Returns B x T x 9 flattened rotation matrices.
    """
    B, T, _ = pred_rot.size()
    converters = {
        "aa": (3, batch_rodrigues),
        "6d": (6, rot6d_to_rotmat),
        "9d": (9, rot9d_to_rotmat),
    }
    flat_mat = None
    entry = converters.get(rep)
    if entry is not None:
        per_rot_dim, convert_fn = entry
        flat_mat = convert_fn(pred_rot.reshape(-1, per_rot_dim))
    # an unknown rep leaves flat_mat as None and fails here, as before
    return flat_mat.reshape((B, T, -1))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
#
|
| 95 |
+
# Many of these functions taken from https://github.com/mkocabas/VIBE/blob/a859e45a907379aa2fba65a7b620b4a2d65dcf1b/lib/utils/geometry.py
|
| 96 |
+
# Please see their license for usage restrictions.
|
| 97 |
+
#
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def matrot2axisangle(matrots):
    """
    Convert batched rotation matrices to axis-angle via OpenCV's Rodrigues.

    :param matrots: N*num_joints*9
    :return: N*num_joints*3
    """
    import cv2

    n_batch = matrots.shape[0]
    mats = matrots.reshape([n_batch, -1, 9])
    out_axisangle = []
    for b_idx in range(mats.shape[0]):
        # one Rodrigues conversion per joint in this batch element
        per_joint = [
            cv2.Rodrigues(mats[b_idx, j_idx : j_idx + 1, :].reshape(3, 3))[0].reshape(
                (1, 3)
            )
            for j_idx in range(mats.shape[1])
        ]
        out_axisangle.append(np.array(per_joint).reshape([1, -1, 3]))
    return np.vstack(out_axisangle)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def axisangle2matrots(axisangle):
    """
    Convert batched axis-angle rotations to matrices via OpenCV's Rodrigues.

    :param axisangle: N*num_joints*3
    :return: N*num_joints*9
    """
    import cv2

    n_batch = axisangle.shape[0]
    aa = axisangle.reshape([n_batch, -1, 3])
    out_matrot = []
    for b_idx in range(aa.shape[0]):
        # one Rodrigues conversion per joint in this batch element
        per_joint = [
            cv2.Rodrigues(aa[b_idx, j_idx : j_idx + 1, :].reshape(1, 3))[0]
            for j_idx in range(aa.shape[1])
        ]
        out_matrot.append(np.array(per_joint).reshape([1, -1, 9]))
    return np.vstack(out_matrot)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def make_rot_homog(rotation_matrix):
    """
    Append a homogeneous column [0, 0, 1]^T to a batch of 3x3 rotation
    matrices, producing B x 3 x 4. Inputs of any other shape are returned
    unchanged.
    """
    if rotation_matrix.shape[1:] != (3, 3):
        return rotation_matrix
    mats = rotation_matrix.reshape(-1, 3, 3)
    # constant last column, broadcast across the batch
    last_col = (
        torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device)
        .reshape(1, 3, 1)
        .expand(mats.shape[0], -1, -1)
    )
    return torch.cat([mats, last_col], dim=-1)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def skew(v):
    """
    Returns antisymmetric (B x 3 x 3) matrices from vectors v: B x 3.
    Sign convention matches the original: [0,1] = +v_z, [0,2] = +v_y,
    [1,2] = +v_x (the transpose of the usual cross-product matrix).
    """
    n_batch, dim = v.size()
    assert dim == 3
    vx, vy, vz = v[:, 0], v[:, 1], v[:, 2]
    zero = torch.zeros_like(vx)
    return torch.stack(
        [
            torch.stack([zero, vz, vy], dim=1),
            torch.stack([-vz, zero, vx], dim=1),
            torch.stack([-vy, -vx, zero], dim=1),
        ],
        dim=1,
    )
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
    """Calculates rotation matrices for a batch of axis-angle vectors.

    Parameters
    ----------
    rot_vecs: torch.tensor Nx3
        array of N axis-angle vectors
    Returns
    -------
    R: torch.tensor Nx3x3
        The rotation matrices for the given axis-angle parameters
    """
    n = rot_vecs.shape[0]
    dev = rot_vecs.device

    # small offset keeps the norm (and hence the division) away from zero
    theta = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
    axis = rot_vecs / theta

    # N x 1 x 1 for broadcasting against the 3x3 terms below
    cos_t = torch.unsqueeze(torch.cos(theta), dim=1)
    sin_t = torch.unsqueeze(torch.sin(theta), dim=1)

    ax, ay, az = axis[:, 0], axis[:, 1], axis[:, 2]
    zero = torch.zeros(n, dtype=dtype, device=dev)
    # cross-product matrix K of the unit axis
    K = torch.stack(
        [zero, -az, ay, az, zero, -ax, -ay, ax, zero], dim=1
    ).reshape(n, 3, 3)

    # Rodrigues formula: R = I + sin(t) K + (1 - cos(t)) K^2
    eye = torch.eye(3, dtype=dtype, device=dev).unsqueeze(dim=0)
    return eye + sin_t * K + (1 - cos_t) * torch.bmm(K, K)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def quat2mat(quat):
    """
    This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L50
    Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [batch_size, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
    """
    # normalize so non-unit quaternions still map to valid rotations
    unit = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = unit.unbind(dim=1)

    ww, xx, yy, zz = w * w, x * x, y * y, z * z
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    # assemble the standard quaternion-to-matrix entries row by row
    row0 = torch.stack([ww + xx - yy - zz, 2 * xy - 2 * wz, 2 * wy + 2 * xz], dim=1)
    row1 = torch.stack([2 * wz + 2 * xy, ww - xx + yy - zz, 2 * yz - 2 * wx], dim=1)
    row2 = torch.stack([2 * xz - 2 * wy, 2 * wx + 2 * yz, ww - xx - yy + zz], dim=1)
    return torch.stack([row0, row1, row2], dim=1)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Input:
        (B,6) Batch of 6-D rotation representations
    Output:
        (B,3,3) Batch of corresponding rotation matrices
    """
    cols = x.reshape(-1, 3, 2)
    raw_x = cols[..., 0]
    raw_y = cols[..., 1]
    # Gram-Schmidt: normalize the first column, orthogonalize the second
    e1 = F.normalize(raw_x)
    proj = torch.einsum("bi,bi->b", e1, raw_y).unsqueeze(-1)
    e2 = F.normalize(raw_y - proj * e1)
    # third column completes the right-handed frame
    e3 = torch.linalg.cross(e1, e2)
    return torch.stack((e1, e2, e3), dim=-1)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def rot9d_to_rotmat(x):
    """
    Projects a 9D rotation output to the closest valid 3x3 rotation matrix.
    Based on Levinson et al., An Analysis of SVD for Deep Rotation Estimation.

    Input:
        (B, 9)
    Output:
        (B, 9)
    """
    n = x.size()[0]
    mats = x.reshape((n, 3, 3))
    u, s, v = torch.svd(mats)
    vt = v.transpose(-2, -1)

    # replace singular values with identity, flipping the last sign when
    # needed so the result has determinant +1 (a proper rotation)
    sigma = torch.eye(3).to(x).reshape((1, 3, 3)).expand_as(mats).clone()
    sigma[:, 2, 2] = torch.det(torch.matmul(u, vt))
    projected = torch.matmul(torch.matmul(u, sigma), vt)

    return projected.reshape((n, 9))
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def rotation_matrix_to_angle_axis(rotation_matrix):
    """
    This function is borrowed from https://github.com/kornia/kornia
    Convert 3x4 rotation matrix to Rodrigues vector
    Args:
        rotation_matrix (Tensor): rotation matrix.
    Returns:
        Tensor: Rodrigues vector transformation.
    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 3)`
    Example:
        >>> input = torch.rand(2, 3, 4)  # Nx4x4
        >>> output = tgm.rotation_matrix_to_angle_axis(input)  # Nx3
    """
    # pad N x 3 x 3 inputs to N x 3 x 4 (same logic as make_rot_homog)
    rotation_matrix = make_rot_homog(rotation_matrix)

    # matrix -> quaternion -> axis-angle
    aa = quaternion_to_angle_axis(rotation_matrix_to_quaternion(rotation_matrix))
    aa[torch.isnan(aa)] = 0.0
    return aa
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    This function is borrowed from https://github.com/kornia/kornia
    Convert 3x4 rotation matrix to 4d quaternion vector
    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
    Args:
        rotation_matrix (Tensor): the rotation matrix to convert.
        eps (float): threshold used to pick the numerically stable branch.
    Return:
        Tensor: the rotation in quaternion (w, x, y, z)
    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 4)`
    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_quaternion(input)  # Nx4
    """
    # input validation: must be a tensor of shape N x 3 x 4
    if not torch.is_tensor(rotation_matrix):
        raise TypeError(
            "Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix))
        )

    if len(rotation_matrix.shape) > 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(
                rotation_matrix.shape
            )
        )
    if not rotation_matrix.shape[-2:] == (3, 4):
        raise ValueError(
            "Input size must be a N x 3 x 4 tensor. Got {}".format(
                rotation_matrix.shape
            )
        )

    # work with the transpose so indexing below matches the reference algorithm
    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    # branch selectors based on the diagonal; each q_i below is numerically
    # stable only in its own region, so all four candidates are computed and
    # blended with masks
    mask_d2 = rmat_t[:, 2, 2] < eps

    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    # candidate 0: dominant x component
    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack(
        [
            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
            t0,
            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
        ],
        -1,
    )
    t0_rep = t0.repeat(4, 1).t()

    # candidate 1: dominant y component
    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack(
        [
            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
            t1,
            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
        ],
        -1,
    )
    t1_rep = t1.repeat(4, 1).t()

    # candidate 2: dominant z component
    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack(
        [
            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
            t2,
        ],
        -1,
    )
    t2_rep = t2.repeat(4, 1).t()

    # candidate 3: dominant w component (trace-positive case)
    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack(
        [
            t3,
            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
        ],
        -1,
    )
    t3_rep = t3.repeat(4, 1).t()

    # exactly one of the four masks is true per batch element
    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * ~mask_d0_d1
    mask_c2 = ~mask_d2 * mask_d0_nd1
    mask_c3 = ~mask_d2 * ~mask_d0_nd1
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    # select the stable candidate and normalize: q = 0.5 * q_i / sqrt(t_i)
    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(
        t0_rep * mask_c0
        + t1_rep * mask_c1
        + t2_rep * mask_c2  # noqa
        + t3_rep * mask_c3
    )  # noqa
    q *= 0.5
    return q
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
    """
    This function is borrowed from https://github.com/kornia/kornia
    Convert quaternion vector to angle axis of rotation.
    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
    Args:
        quaternion (torch.Tensor): tensor with quaternions (w, x, y, z).
    Return:
        torch.Tensor: tensor with angle axis of rotation.
    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`
    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError(
            "Input type is not a torch.Tensor. Got {}".format(type(quaternion))
        )

    if not quaternion.shape[-1] == 4:
        raise ValueError(
            "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
        )

    # unpack components
    qw = quaternion[..., 0]
    qx = quaternion[..., 1]
    qy = quaternion[..., 2]
    qz = quaternion[..., 3]

    # |sin(theta/2)|^2 from the vector part
    sin_sq = qx * qx + qy * qy + qz * qz
    sin_t = torch.sqrt(sin_sq)

    # pick the atan2 branch so the resulting angle is in the shorter direction
    two_theta = 2.0 * torch.where(
        qw < 0.0,
        torch.atan2(-sin_t, -qw),
        torch.atan2(sin_t, qw),
    )

    # scale factor: theta / sin(theta/2), with a small-angle fallback of 2
    scale = torch.where(
        sin_sq > 0.0, two_theta / sin_t, 2.0 * torch.ones_like(sin_t)
    )

    return torch.stack([qx * scale, qy * scale, qz * scale], dim=-1)
|
slahmr/slahmr/job_specs/3dpw_test_split.txt
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
downtown_arguing_00 data.start_idx=0 data.end_idx=100
|
| 2 |
+
downtown_arguing_00 data.start_idx=100 data.end_idx=200
|
| 3 |
+
downtown_arguing_00 data.start_idx=200 data.end_idx=300
|
| 4 |
+
downtown_arguing_00 data.start_idx=300 data.end_idx=400
|
| 5 |
+
downtown_arguing_00 data.start_idx=400 data.end_idx=500
|
| 6 |
+
downtown_arguing_00 data.start_idx=500 data.end_idx=600
|
| 7 |
+
downtown_arguing_00 data.start_idx=600 data.end_idx=700
|
| 8 |
+
downtown_arguing_00 data.start_idx=700 data.end_idx=898
|
| 9 |
+
downtown_bar_00 data.start_idx=0 data.end_idx=100
|
| 10 |
+
downtown_bar_00 data.start_idx=100 data.end_idx=200
|
| 11 |
+
downtown_bar_00 data.start_idx=200 data.end_idx=300
|
| 12 |
+
downtown_bar_00 data.start_idx=300 data.end_idx=400
|
| 13 |
+
downtown_bar_00 data.start_idx=400 data.end_idx=500
|
| 14 |
+
downtown_bar_00 data.start_idx=500 data.end_idx=600
|
| 15 |
+
downtown_bar_00 data.start_idx=600 data.end_idx=700
|
| 16 |
+
downtown_bar_00 data.start_idx=700 data.end_idx=800
|
| 17 |
+
downtown_bar_00 data.start_idx=800 data.end_idx=900
|
| 18 |
+
downtown_bar_00 data.start_idx=900 data.end_idx=1000
|
| 19 |
+
downtown_bar_00 data.start_idx=1000 data.end_idx=1100
|
| 20 |
+
downtown_bar_00 data.start_idx=1100 data.end_idx=1200
|
| 21 |
+
downtown_bar_00 data.start_idx=1200 data.end_idx=1300
|
| 22 |
+
downtown_bar_00 data.start_idx=1300 data.end_idx=1403
|
| 23 |
+
downtown_bus_00 data.start_idx=0 data.end_idx=100
|
| 24 |
+
downtown_bus_00 data.start_idx=100 data.end_idx=200
|
| 25 |
+
downtown_bus_00 data.start_idx=200 data.end_idx=300
|
| 26 |
+
downtown_bus_00 data.start_idx=300 data.end_idx=400
|
| 27 |
+
downtown_bus_00 data.start_idx=400 data.end_idx=500
|
| 28 |
+
downtown_bus_00 data.start_idx=500 data.end_idx=600
|
| 29 |
+
downtown_bus_00 data.start_idx=600 data.end_idx=700
|
| 30 |
+
downtown_bus_00 data.start_idx=700 data.end_idx=800
|
| 31 |
+
downtown_bus_00 data.start_idx=800 data.end_idx=900
|
| 32 |
+
downtown_bus_00 data.start_idx=900 data.end_idx=1000
|
| 33 |
+
downtown_bus_00 data.start_idx=1000 data.end_idx=1100
|
| 34 |
+
downtown_bus_00 data.start_idx=1100 data.end_idx=1200
|
| 35 |
+
downtown_bus_00 data.start_idx=1200 data.end_idx=1300
|
| 36 |
+
downtown_bus_00 data.start_idx=1300 data.end_idx=1400
|
| 37 |
+
downtown_bus_00 data.start_idx=1400 data.end_idx=1500
|
| 38 |
+
downtown_bus_00 data.start_idx=1500 data.end_idx=1600
|
| 39 |
+
downtown_bus_00 data.start_idx=1600 data.end_idx=1700
|
| 40 |
+
downtown_bus_00 data.start_idx=1700 data.end_idx=1800
|
| 41 |
+
downtown_bus_00 data.start_idx=1800 data.end_idx=1900
|
| 42 |
+
downtown_bus_00 data.start_idx=1900 data.end_idx=2000
|
| 43 |
+
downtown_bus_00 data.start_idx=2000 data.end_idx=2178
|
| 44 |
+
downtown_cafe_00 data.start_idx=0 data.end_idx=100
|
| 45 |
+
downtown_cafe_00 data.start_idx=100 data.end_idx=200
|
| 46 |
+
downtown_cafe_00 data.start_idx=200 data.end_idx=300
|
| 47 |
+
downtown_cafe_00 data.start_idx=300 data.end_idx=400
|
| 48 |
+
downtown_cafe_00 data.start_idx=400 data.end_idx=500
|
| 49 |
+
downtown_cafe_00 data.start_idx=500 data.end_idx=600
|
| 50 |
+
downtown_cafe_00 data.start_idx=600 data.end_idx=700
|
| 51 |
+
downtown_cafe_00 data.start_idx=700 data.end_idx=800
|
| 52 |
+
downtown_cafe_00 data.start_idx=800 data.end_idx=900
|
| 53 |
+
downtown_cafe_00 data.start_idx=900 data.end_idx=1000
|
| 54 |
+
downtown_cafe_00 data.start_idx=1000 data.end_idx=1100
|
| 55 |
+
downtown_cafe_00 data.start_idx=1100 data.end_idx=1201
|
| 56 |
+
downtown_car_00 data.start_idx=0 data.end_idx=100
|
| 57 |
+
downtown_car_00 data.start_idx=100 data.end_idx=200
|
| 58 |
+
downtown_car_00 data.start_idx=200 data.end_idx=300
|
| 59 |
+
downtown_car_00 data.start_idx=300 data.end_idx=400
|
| 60 |
+
downtown_car_00 data.start_idx=400 data.end_idx=500
|
| 61 |
+
downtown_car_00 data.start_idx=500 data.end_idx=600
|
| 62 |
+
downtown_car_00 data.start_idx=600 data.end_idx=700
|
| 63 |
+
downtown_car_00 data.start_idx=700 data.end_idx=800
|
| 64 |
+
downtown_car_00 data.start_idx=800 data.end_idx=900
|
| 65 |
+
downtown_car_00 data.start_idx=900 data.end_idx=1020
|
| 66 |
+
downtown_crossStreets_00 data.start_idx=0 data.end_idx=100
|
| 67 |
+
downtown_crossStreets_00 data.start_idx=100 data.end_idx=200
|
| 68 |
+
downtown_crossStreets_00 data.start_idx=200 data.end_idx=300
|
| 69 |
+
downtown_crossStreets_00 data.start_idx=300 data.end_idx=400
|
| 70 |
+
downtown_crossStreets_00 data.start_idx=400 data.end_idx=588
|
| 71 |
+
downtown_downstairs_00 data.start_idx=0 data.end_idx=100
|
| 72 |
+
downtown_downstairs_00 data.start_idx=100 data.end_idx=200
|
| 73 |
+
downtown_downstairs_00 data.start_idx=200 data.end_idx=300
|
| 74 |
+
downtown_downstairs_00 data.start_idx=300 data.end_idx=400
|
| 75 |
+
downtown_downstairs_00 data.start_idx=400 data.end_idx=500
|
| 76 |
+
downtown_downstairs_00 data.start_idx=500 data.end_idx=600
|
| 77 |
+
downtown_downstairs_00 data.start_idx=600 data.end_idx=700
|
| 78 |
+
downtown_downstairs_00 data.start_idx=700 data.end_idx=857
|
| 79 |
+
downtown_enterShop_00 data.start_idx=0 data.end_idx=100
|
| 80 |
+
downtown_enterShop_00 data.start_idx=100 data.end_idx=200
|
| 81 |
+
downtown_enterShop_00 data.start_idx=200 data.end_idx=300
|
| 82 |
+
downtown_enterShop_00 data.start_idx=300 data.end_idx=400
|
| 83 |
+
downtown_enterShop_00 data.start_idx=400 data.end_idx=500
|
| 84 |
+
downtown_enterShop_00 data.start_idx=500 data.end_idx=600
|
| 85 |
+
downtown_enterShop_00 data.start_idx=600 data.end_idx=700
|
| 86 |
+
downtown_enterShop_00 data.start_idx=700 data.end_idx=800
|
| 87 |
+
downtown_enterShop_00 data.start_idx=800 data.end_idx=900
|
| 88 |
+
downtown_enterShop_00 data.start_idx=900 data.end_idx=1000
|
| 89 |
+
downtown_enterShop_00 data.start_idx=1000 data.end_idx=1100
|
| 90 |
+
downtown_enterShop_00 data.start_idx=1100 data.end_idx=1200
|
| 91 |
+
downtown_enterShop_00 data.start_idx=1200 data.end_idx=1300
|
| 92 |
+
downtown_enterShop_00 data.start_idx=1300 data.end_idx=1449
|
| 93 |
+
downtown_rampAndStairs_00 data.start_idx=0 data.end_idx=100
|
| 94 |
+
downtown_rampAndStairs_00 data.start_idx=100 data.end_idx=200
|
| 95 |
+
downtown_rampAndStairs_00 data.start_idx=200 data.end_idx=300
|
| 96 |
+
downtown_rampAndStairs_00 data.start_idx=300 data.end_idx=400
|
| 97 |
+
downtown_rampAndStairs_00 data.start_idx=400 data.end_idx=500
|
| 98 |
+
downtown_rampAndStairs_00 data.start_idx=500 data.end_idx=600
|
| 99 |
+
downtown_rampAndStairs_00 data.start_idx=600 data.end_idx=700
|
| 100 |
+
downtown_rampAndStairs_00 data.start_idx=700 data.end_idx=800
|
| 101 |
+
downtown_rampAndStairs_00 data.start_idx=800 data.end_idx=984
|
| 102 |
+
downtown_runForBus_00 data.start_idx=0 data.end_idx=100
|
| 103 |
+
downtown_runForBus_00 data.start_idx=100 data.end_idx=200
|
| 104 |
+
downtown_runForBus_00 data.start_idx=200 data.end_idx=300
|
| 105 |
+
downtown_runForBus_00 data.start_idx=300 data.end_idx=400
|
| 106 |
+
downtown_runForBus_00 data.start_idx=400 data.end_idx=500
|
| 107 |
+
downtown_runForBus_00 data.start_idx=500 data.end_idx=600
|
| 108 |
+
downtown_runForBus_00 data.start_idx=600 data.end_idx=731
|
| 109 |
+
downtown_runForBus_01 data.start_idx=0 data.end_idx=100
|
| 110 |
+
downtown_runForBus_01 data.start_idx=100 data.end_idx=200
|
| 111 |
+
downtown_runForBus_01 data.start_idx=200 data.end_idx=300
|
| 112 |
+
downtown_runForBus_01 data.start_idx=300 data.end_idx=400
|
| 113 |
+
downtown_runForBus_01 data.start_idx=400 data.end_idx=500
|
| 114 |
+
downtown_runForBus_01 data.start_idx=500 data.end_idx=600
|
| 115 |
+
downtown_runForBus_01 data.start_idx=600 data.end_idx=783
|
| 116 |
+
downtown_sitOnStairs_00 data.start_idx=0 data.end_idx=100
|
| 117 |
+
downtown_sitOnStairs_00 data.start_idx=100 data.end_idx=200
|
| 118 |
+
downtown_sitOnStairs_00 data.start_idx=200 data.end_idx=300
|
| 119 |
+
downtown_sitOnStairs_00 data.start_idx=300 data.end_idx=400
|
| 120 |
+
downtown_sitOnStairs_00 data.start_idx=400 data.end_idx=500
|
| 121 |
+
downtown_sitOnStairs_00 data.start_idx=500 data.end_idx=600
|
| 122 |
+
downtown_sitOnStairs_00 data.start_idx=600 data.end_idx=700
|
| 123 |
+
downtown_sitOnStairs_00 data.start_idx=700 data.end_idx=800
|
| 124 |
+
downtown_sitOnStairs_00 data.start_idx=800 data.end_idx=900
|
| 125 |
+
downtown_sitOnStairs_00 data.start_idx=900 data.end_idx=1000
|
| 126 |
+
downtown_sitOnStairs_00 data.start_idx=1000 data.end_idx=1100
|
| 127 |
+
downtown_sitOnStairs_00 data.start_idx=1100 data.end_idx=1200
|
| 128 |
+
downtown_sitOnStairs_00 data.start_idx=1200 data.end_idx=1337
|
| 129 |
+
downtown_stairs_00 data.start_idx=0 data.end_idx=100
|
| 130 |
+
downtown_stairs_00 data.start_idx=100 data.end_idx=200
|
| 131 |
+
downtown_stairs_00 data.start_idx=200 data.end_idx=300
|
| 132 |
+
downtown_stairs_00 data.start_idx=300 data.end_idx=400
|
| 133 |
+
downtown_stairs_00 data.start_idx=400 data.end_idx=500
|
| 134 |
+
downtown_stairs_00 data.start_idx=500 data.end_idx=600
|
| 135 |
+
downtown_stairs_00 data.start_idx=600 data.end_idx=700
|
| 136 |
+
downtown_stairs_00 data.start_idx=700 data.end_idx=800
|
| 137 |
+
downtown_stairs_00 data.start_idx=800 data.end_idx=900
|
| 138 |
+
downtown_stairs_00 data.start_idx=900 data.end_idx=1000
|
| 139 |
+
downtown_stairs_00 data.start_idx=1000 data.end_idx=1100
|
| 140 |
+
downtown_stairs_00 data.start_idx=1100 data.end_idx=1240
|
| 141 |
+
downtown_upstairs_00 data.start_idx=0 data.end_idx=100
|
| 142 |
+
downtown_upstairs_00 data.start_idx=100 data.end_idx=200
|
| 143 |
+
downtown_upstairs_00 data.start_idx=200 data.end_idx=300
|
| 144 |
+
downtown_upstairs_00 data.start_idx=300 data.end_idx=400
|
| 145 |
+
downtown_upstairs_00 data.start_idx=400 data.end_idx=500
|
| 146 |
+
downtown_upstairs_00 data.start_idx=500 data.end_idx=600
|
| 147 |
+
downtown_upstairs_00 data.start_idx=600 data.end_idx=700
|
| 148 |
+
downtown_upstairs_00 data.start_idx=700 data.end_idx=845
|
| 149 |
+
downtown_walkBridge_01 data.start_idx=0 data.end_idx=100
|
| 150 |
+
downtown_walkBridge_01 data.start_idx=100 data.end_idx=200
|
| 151 |
+
downtown_walkBridge_01 data.start_idx=200 data.end_idx=300
|
| 152 |
+
downtown_walkBridge_01 data.start_idx=300 data.end_idx=400
|
| 153 |
+
downtown_walkBridge_01 data.start_idx=400 data.end_idx=500
|
| 154 |
+
downtown_walkBridge_01 data.start_idx=500 data.end_idx=600
|
| 155 |
+
downtown_walkBridge_01 data.start_idx=600 data.end_idx=700
|
| 156 |
+
downtown_walkBridge_01 data.start_idx=700 data.end_idx=800
|
| 157 |
+
downtown_walkBridge_01 data.start_idx=800 data.end_idx=900
|
| 158 |
+
downtown_walkBridge_01 data.start_idx=900 data.end_idx=1000
|
| 159 |
+
downtown_walkBridge_01 data.start_idx=1000 data.end_idx=1100
|
| 160 |
+
downtown_walkBridge_01 data.start_idx=1100 data.end_idx=1200
|
| 161 |
+
downtown_walkBridge_01 data.start_idx=1200 data.end_idx=1372
|
| 162 |
+
downtown_walkUphill_00 data.start_idx=0 data.end_idx=100
|
| 163 |
+
downtown_walkUphill_00 data.start_idx=100 data.end_idx=200
|
| 164 |
+
downtown_walkUphill_00 data.start_idx=200 data.end_idx=388
|
| 165 |
+
downtown_walking_00 data.start_idx=0 data.end_idx=100
|
| 166 |
+
downtown_walking_00 data.start_idx=100 data.end_idx=200
|
| 167 |
+
downtown_walking_00 data.start_idx=200 data.end_idx=300
|
| 168 |
+
downtown_walking_00 data.start_idx=300 data.end_idx=400
|
| 169 |
+
downtown_walking_00 data.start_idx=400 data.end_idx=500
|
| 170 |
+
downtown_walking_00 data.start_idx=500 data.end_idx=600
|
| 171 |
+
downtown_walking_00 data.start_idx=600 data.end_idx=700
|
| 172 |
+
downtown_walking_00 data.start_idx=700 data.end_idx=800
|
| 173 |
+
downtown_walking_00 data.start_idx=800 data.end_idx=900
|
| 174 |
+
downtown_walking_00 data.start_idx=900 data.end_idx=1000
|
| 175 |
+
downtown_walking_00 data.start_idx=1000 data.end_idx=1100
|
| 176 |
+
downtown_walking_00 data.start_idx=1100 data.end_idx=1200
|
| 177 |
+
downtown_walking_00 data.start_idx=1200 data.end_idx=1387
|
| 178 |
+
downtown_warmWelcome_00 data.start_idx=0 data.end_idx=100
|
| 179 |
+
downtown_warmWelcome_00 data.start_idx=100 data.end_idx=200
|
| 180 |
+
downtown_warmWelcome_00 data.start_idx=200 data.end_idx=300
|
| 181 |
+
downtown_warmWelcome_00 data.start_idx=300 data.end_idx=400
|
| 182 |
+
downtown_warmWelcome_00 data.start_idx=400 data.end_idx=589
|
| 183 |
+
downtown_weeklyMarket_00 data.start_idx=0 data.end_idx=100
|
| 184 |
+
downtown_weeklyMarket_00 data.start_idx=100 data.end_idx=200
|
| 185 |
+
downtown_weeklyMarket_00 data.start_idx=200 data.end_idx=300
|
| 186 |
+
downtown_weeklyMarket_00 data.start_idx=300 data.end_idx=400
|
| 187 |
+
downtown_weeklyMarket_00 data.start_idx=400 data.end_idx=500
|
| 188 |
+
downtown_weeklyMarket_00 data.start_idx=500 data.end_idx=600
|
| 189 |
+
downtown_weeklyMarket_00 data.start_idx=600 data.end_idx=700
|
| 190 |
+
downtown_weeklyMarket_00 data.start_idx=700 data.end_idx=800
|
| 191 |
+
downtown_weeklyMarket_00 data.start_idx=800 data.end_idx=900
|
| 192 |
+
downtown_weeklyMarket_00 data.start_idx=900 data.end_idx=1000
|
| 193 |
+
downtown_weeklyMarket_00 data.start_idx=1000 data.end_idx=1193
|
| 194 |
+
downtown_windowShopping_00 data.start_idx=0 data.end_idx=100
|
| 195 |
+
downtown_windowShopping_00 data.start_idx=100 data.end_idx=200
|
| 196 |
+
downtown_windowShopping_00 data.start_idx=200 data.end_idx=300
|
| 197 |
+
downtown_windowShopping_00 data.start_idx=300 data.end_idx=400
|
| 198 |
+
downtown_windowShopping_00 data.start_idx=400 data.end_idx=500
|
| 199 |
+
downtown_windowShopping_00 data.start_idx=500 data.end_idx=600
|
| 200 |
+
downtown_windowShopping_00 data.start_idx=600 data.end_idx=700
|
| 201 |
+
downtown_windowShopping_00 data.start_idx=700 data.end_idx=800
|
| 202 |
+
downtown_windowShopping_00 data.start_idx=800 data.end_idx=900
|
| 203 |
+
downtown_windowShopping_00 data.start_idx=900 data.end_idx=1000
|
| 204 |
+
downtown_windowShopping_00 data.start_idx=1000 data.end_idx=1100
|
| 205 |
+
downtown_windowShopping_00 data.start_idx=1100 data.end_idx=1200
|
| 206 |
+
downtown_windowShopping_00 data.start_idx=1200 data.end_idx=1300
|
| 207 |
+
downtown_windowShopping_00 data.start_idx=1300 data.end_idx=1400
|
| 208 |
+
downtown_windowShopping_00 data.start_idx=1400 data.end_idx=1500
|
| 209 |
+
downtown_windowShopping_00 data.start_idx=1500 data.end_idx=1600
|
| 210 |
+
downtown_windowShopping_00 data.start_idx=1600 data.end_idx=1700
|
| 211 |
+
downtown_windowShopping_00 data.start_idx=1700 data.end_idx=1800
|
| 212 |
+
downtown_windowShopping_00 data.start_idx=1800 data.end_idx=1948
|
| 213 |
+
flat_guitar_01 data.start_idx=0 data.end_idx=100
|
| 214 |
+
flat_guitar_01 data.start_idx=100 data.end_idx=200
|
| 215 |
+
flat_guitar_01 data.start_idx=200 data.end_idx=300
|
| 216 |
+
flat_guitar_01 data.start_idx=300 data.end_idx=400
|
| 217 |
+
flat_guitar_01 data.start_idx=400 data.end_idx=500
|
| 218 |
+
flat_guitar_01 data.start_idx=500 data.end_idx=600
|
| 219 |
+
flat_guitar_01 data.start_idx=600 data.end_idx=748
|
| 220 |
+
flat_packBags_00 data.start_idx=0 data.end_idx=100
|
| 221 |
+
flat_packBags_00 data.start_idx=100 data.end_idx=200
|
| 222 |
+
flat_packBags_00 data.start_idx=200 data.end_idx=300
|
| 223 |
+
flat_packBags_00 data.start_idx=300 data.end_idx=400
|
| 224 |
+
flat_packBags_00 data.start_idx=400 data.end_idx=500
|
| 225 |
+
flat_packBags_00 data.start_idx=500 data.end_idx=600
|
| 226 |
+
flat_packBags_00 data.start_idx=600 data.end_idx=700
|
| 227 |
+
flat_packBags_00 data.start_idx=700 data.end_idx=800
|
| 228 |
+
flat_packBags_00 data.start_idx=800 data.end_idx=900
|
| 229 |
+
flat_packBags_00 data.start_idx=900 data.end_idx=1000
|
| 230 |
+
flat_packBags_00 data.start_idx=1000 data.end_idx=1100
|
| 231 |
+
flat_packBags_00 data.start_idx=1100 data.end_idx=1279
|
| 232 |
+
office_phoneCall_00 data.start_idx=0 data.end_idx=100
|
| 233 |
+
office_phoneCall_00 data.start_idx=100 data.end_idx=200
|
| 234 |
+
office_phoneCall_00 data.start_idx=200 data.end_idx=300
|
| 235 |
+
office_phoneCall_00 data.start_idx=300 data.end_idx=400
|
| 236 |
+
office_phoneCall_00 data.start_idx=400 data.end_idx=500
|
| 237 |
+
office_phoneCall_00 data.start_idx=500 data.end_idx=600
|
| 238 |
+
office_phoneCall_00 data.start_idx=600 data.end_idx=700
|
| 239 |
+
office_phoneCall_00 data.start_idx=700 data.end_idx=880
|
| 240 |
+
outdoors_fencing_01 data.start_idx=0 data.end_idx=100
|
| 241 |
+
outdoors_fencing_01 data.start_idx=100 data.end_idx=200
|
| 242 |
+
outdoors_fencing_01 data.start_idx=200 data.end_idx=300
|
| 243 |
+
outdoors_fencing_01 data.start_idx=300 data.end_idx=400
|
| 244 |
+
outdoors_fencing_01 data.start_idx=400 data.end_idx=500
|
| 245 |
+
outdoors_fencing_01 data.start_idx=500 data.end_idx=600
|
| 246 |
+
outdoors_fencing_01 data.start_idx=600 data.end_idx=700
|
| 247 |
+
outdoors_fencing_01 data.start_idx=700 data.end_idx=800
|
| 248 |
+
outdoors_fencing_01 data.start_idx=800 data.end_idx=942
|
slahmr/slahmr/job_specs/davis.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
parkour data=davis fps=24
|
| 2 |
+
lady-running data=davis fps=24
|
| 3 |
+
dance-twirl data=davis fps=24
|
| 4 |
+
lindy-hop data=davis fps=24
|
| 5 |
+
hike data=davis fps=24
|
| 6 |
+
judo data=davis fps=24
|
| 7 |
+
lucia data=davis fps=24
|
| 8 |
+
tennis data=davis fps=24
|
| 9 |
+
skate-park data=davis fps=24
|
| 10 |
+
boxing-fisheye data=davis fps=24
|
| 11 |
+
crossing data=davis fps=24
|
| 12 |
+
loading data=davis fps=24
|
| 13 |
+
bike-packing data=davis fps=24
|
| 14 |
+
dance-jump data=davis fps=24
|
| 15 |
+
hockey data=davis fps=24
|
| 16 |
+
india data=davis fps=24
|
| 17 |
+
kid-football data=davis fps=24
|
| 18 |
+
longboard data=davis fps=24
|
| 19 |
+
schoolgirls data=davis fps=24
|
| 20 |
+
snowboard data=davis fps=24
|
| 21 |
+
stunt data=davis fps=24
|
| 22 |
+
swing data=davis fps=24
|
| 23 |
+
dancing data=davis fps=24
|
| 24 |
+
kite-walk data=davis fps=24
|