camenduru commited on
Commit
b4e342b
·
1 Parent(s): 599b29e

thanks to shubham-goel ❤

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +7 -0
  2. slahmr/.gitignore +140 -0
  3. slahmr/.gitmodules +6 -0
  4. slahmr/LICENSE +21 -0
  5. slahmr/README.md +167 -0
  6. slahmr/download_models.sh +5 -0
  7. slahmr/env.yaml +45 -0
  8. slahmr/env_build.yaml +127 -0
  9. slahmr/install.sh +35 -0
  10. slahmr/requirements.txt +27 -0
  11. slahmr/setup.py +9 -0
  12. slahmr/slahmr.zip +3 -0
  13. slahmr/slahmr/__init__.py +0 -0
  14. slahmr/slahmr/body_model/__init__.py +3 -0
  15. slahmr/slahmr/body_model/body_model.py +142 -0
  16. slahmr/slahmr/body_model/specs.py +554 -0
  17. slahmr/slahmr/body_model/utils.py +56 -0
  18. slahmr/slahmr/confs/config.yaml +51 -0
  19. slahmr/slahmr/confs/data/3dpw.yaml +18 -0
  20. slahmr/slahmr/confs/data/3dpw_gt.yaml +18 -0
  21. slahmr/slahmr/confs/data/custom.yaml +17 -0
  22. slahmr/slahmr/confs/data/davis.yaml +16 -0
  23. slahmr/slahmr/confs/data/egobody.yaml +18 -0
  24. slahmr/slahmr/confs/data/posetrack.yaml +17 -0
  25. slahmr/slahmr/confs/data/video.yaml +24 -0
  26. slahmr/slahmr/confs/init.yaml +13 -0
  27. slahmr/slahmr/confs/optim.yaml +51 -0
  28. slahmr/slahmr/data/__init__.py +2 -0
  29. slahmr/slahmr/data/dataset.py +438 -0
  30. slahmr/slahmr/data/tools.py +108 -0
  31. slahmr/slahmr/data/vidproc.py +82 -0
  32. slahmr/slahmr/eval/__init__.py +0 -0
  33. slahmr/slahmr/eval/associate.py +161 -0
  34. slahmr/slahmr/eval/egobody_utils.py +171 -0
  35. slahmr/slahmr/eval/run_eval.py +289 -0
  36. slahmr/slahmr/eval/split_3dpw.py +99 -0
  37. slahmr/slahmr/eval/split_egobody.py +123 -0
  38. slahmr/slahmr/eval/tools.py +181 -0
  39. slahmr/slahmr/geometry/__init__.py +5 -0
  40. slahmr/slahmr/geometry/camera.py +348 -0
  41. slahmr/slahmr/geometry/mesh.py +110 -0
  42. slahmr/slahmr/geometry/pcl.py +60 -0
  43. slahmr/slahmr/geometry/plane.py +101 -0
  44. slahmr/slahmr/geometry/rotation.py +284 -0
  45. slahmr/slahmr/humor/__init__.py +0 -0
  46. slahmr/slahmr/humor/amass_utils.py +148 -0
  47. slahmr/slahmr/humor/humor_model.py +1655 -0
  48. slahmr/slahmr/humor/transforms.py +472 -0
  49. slahmr/slahmr/job_specs/3dpw_test_split.txt +248 -0
  50. slahmr/slahmr/job_specs/davis.txt +24 -0
.gitattributes CHANGED
@@ -34,3 +34,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  lietorch-0.2-py3.10-linux-x86_64.egg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  lietorch-0.2-py3.10-linux-x86_64.egg filter=lfs diff=lfs merge=lfs -text
37
+ slahmr/teaser.png filter=lfs diff=lfs merge=lfs -text
38
+ slahmr/third-party/DROID-SLAM/build/lib.linux-x86_64-3.10/droid_backends.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
39
+ slahmr/third-party/DROID-SLAM/build/lib.linux-x86_64-3.10/lietorch_backends.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
40
+ slahmr/third-party/DROID-SLAM/build/temp.linux-x86_64-3.10/src/droid_kernels.o filter=lfs diff=lfs merge=lfs -text
41
+ slahmr/third-party/DROID-SLAM/build/temp.linux-x86_64-3.10/thirdparty/lietorch/lietorch/src/lietorch_gpu.o filter=lfs diff=lfs merge=lfs -text
42
+ slahmr/third-party/DROID-SLAM/thirdparty/lietorch/examples/registration/assets/registration.gif filter=lfs diff=lfs merge=lfs -text
43
+ slahmr/third-party/ViTPose/demo/resources/demo_coco.gif filter=lfs diff=lfs merge=lfs -text
slahmr/.gitignore ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # data
7
+ *outputs*
8
+ *renders*
9
+ *cache*
10
+ *checkpoints*
11
+ *_DATA
12
+
13
+ *.swp
14
+ *.out
15
+ *.err
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ pip-wheel-metadata/
35
+ share/python-wheels/
36
+ *.egg-info/
37
+ .installed.cfg
38
+ *.egg
39
+ MANIFEST
40
+
41
+ # PyInstaller
42
+ # Usually these files are written by a python script from a template
43
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
44
+ *.manifest
45
+ *.spec
46
+
47
+ # Installer logs
48
+ pip-log.txt
49
+ pip-delete-this-directory.txt
50
+
51
+ # Unit test / coverage reports
52
+ htmlcov/
53
+ .tox/
54
+ .nox/
55
+ .coverage
56
+ .coverage.*
57
+ .cache
58
+ nosetests.xml
59
+ coverage.xml
60
+ *.cover
61
+ *.py,cover
62
+ .hypothesis/
63
+ .pytest_cache/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ target/
87
+
88
+ # Jupyter Notebook
89
+ .ipynb_checkpoints
90
+
91
+ # IPython
92
+ profile_default/
93
+ ipython_config.py
94
+
95
+ # pyenv
96
+ .python-version
97
+
98
+ # pipenv
99
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
101
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
102
+ # install all needed dependencies.
103
+ #Pipfile.lock
104
+
105
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106
+ __pypackages__/
107
+
108
+ # Celery stuff
109
+ celerybeat-schedule
110
+ celerybeat.pid
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
slahmr/.gitmodules ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [submodule "third-party/DROID-SLAM"]
2
+ path = third-party/DROID-SLAM
3
+ url = https://github.com/princeton-vl/DROID-SLAM.git
4
+ [submodule "third-party/ViTPose"]
5
+ path = third-party/ViTPose
6
+ url = https://github.com/ViTAE-Transformer/ViTPose.git
slahmr/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 vye16
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
slahmr/README.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Decoupling Human and Camera Motion from Videos in the Wild
2
+
3
+ Official PyTorch implementation of the paper Decoupling Human and Camera Motion from Videos in the Wild
4
+
5
+ [Project page](https://vye16.github.io/slahmr/) | [ArXiv](https://arxiv.org/abs/2302.12827)
6
+
7
+ <img src="./teaser.png">
8
+
9
+ ## [<img src="https://i.imgur.com/QCojoJk.png" width="40"> You can run SLAHMR in Google Colab](https://colab.research.google.com/drive/1knzxW3XuxiaBH6hcwx01cs6DfA4azv5E?usp=sharing)
10
+
11
+ ## News
12
+
13
+ - [2023/07] We updated the code to support tracking from [4D Humans](https://shubham-goel.github.io/4dhumans/)! The original code remains in the `release` branch.
14
+ - [2023/02] Original release!
15
+
16
+ ## Getting started
17
+ This code was tested on Ubuntu 22.04 LTS and requires a CUDA-capable GPU.
18
+
19
+ 1. Clone repository and submodules
20
+ ```
21
+ git clone --recursive https://github.com/vye16/slahmr.git
22
+ ```
23
+ or initialize submodules if already cloned
24
+ ```
25
+ git submodule update --init --recursive
26
+ ```
27
+
28
+ 2. Set up conda environment. Run
29
+ ```
30
+ source install.sh
31
+ ```
32
+
33
+ <details>
34
+ <summary>We also include the following steps for trouble-shooting.</summary>
35
+
36
+ * Create environment
37
+ ```
38
+ conda env create -f env.yaml
39
+ conda activate slahmr
40
+ ```
41
+ We use PyTorch 1.13.0 with CUDA 11.7. Please modify according to your setup; we've tested successfully for PyTorch 1.11 as well.
42
+ We've also included `env_build.yaml` to speed up installation using already-solved dependencies, though it might not be compatible with your CUDA driver.
43
+
44
+ * Install PHALP
45
+ ```
46
+ pip install phalp[all]@git+https://github.com/brjathu/PHALP.git
47
+ ```
48
+
49
+ * Install current source repo
50
+ ```
51
+ pip install -e .
52
+ ```
53
+
54
+ * Install ViTPose
55
+ ```
56
+ pip install -v -e third-party/ViTPose
57
+ ```
58
+
59
+ * Install DROID-SLAM (will take a while)
60
+ ```
61
+ cd third-party/DROID-SLAM
62
+ python setup.py install
63
+ ```
64
+ </details>
65
+
66
+ 3. Download models from [here](https://drive.google.com/file/d/1GXAd-45GzGYNENKgQxFQ4PHrBp8wDRlW/view?usp=sharing). Run
67
+ ```
68
+ ./download_models.sh
69
+ ```
70
+ or
71
+ ```
72
+ gdown https://drive.google.com/uc?id=1GXAd-45GzGYNENKgQxFQ4PHrBp8wDRlW
73
+ unzip -q slahmr_dependencies.zip
74
+ rm slahmr_dependencies.zip
75
+ ```
76
+
77
+ All models and checkpoints should have been unpacked in `_DATA`.
78
+
79
+
80
+ ## Fitting to an RGB video:
81
+ For a custom video, you can edit the config file: `slahmr/confs/data/video.yaml`.
82
+ Then, from the `slahmr` directory, you can run:
83
+ ```
84
+ python run_opt.py data=video run_opt=True run_vis=True
85
+ ```
86
+
87
+ We use hydra to launch experiments, and all parameters can be found in `slahmr/confs/config.yaml`.
88
+ If you would like to update any aspect of logging or optimization tuning, update the relevant config files.
89
+
90
+ By default, we will log each run to `outputs/video-val/<DATE>/<VIDEO_NAME>`.
91
+ Each stage of optimization will produce a separate subdirectory, each of which will contain outputs saved throughout the optimization
92
+ and rendered videos of the final result for that stage of optimization.
93
+ The `motion_chunks` directory contains the outputs of the final stage of optimization,
94
+ `root_fit` and `smooth_fit` contain outputs of short, intermediate stages of optimization,
95
+ and `init` contains the initialized outputs before optimization.
96
+
97
+ We've provided a `run_vis.py` script for running visualization from logs after optimization.
98
+ From the `slahmr` directory, run
99
+ ```
100
+ python run_vis.py --log_root <LOG_ROOT>
101
+ ```
102
+ and it will visualize all log subdirectories in `<LOG_ROOT>`.
103
+ Each output npz file will contain the SMPL parameters for all optimized people, the camera intrinsics and extrinsics.
104
+ The `motion_chunks` output will contain additional predictions from the motion prior.
105
+ Please see `run_vis.py` for how to extract the people meshes from the output parameters.
106
+
107
+
108
+ ## Fitting to specific datasets:
109
+ We provide configurations for dataset formats in `slahmr/confs/data`:
110
+ 1. Posetrack in `slahmr/confs/data/posetrack.yaml`
111
+ 2. Egobody in `slahmr/confs/data/egobody.yaml`
112
+ 3. 3DPW in `slahmr/confs/data/3dpw.yaml`
113
+ 4. Custom video in `slahmr/confs/data/video.yaml`
114
+
115
+ **Please make sure to update all paths to data in the config files.**
116
+
117
+ We include tools to both process existing datasets we evaluated on in the paper, and to process custom data and videos.
118
+ We include experiments from the paper on the Egobody, Posetrack, and 3DPW datasets.
119
+
120
+ If you want to run on a large number of videos, or if you want to select specific people tracks for optimization,
121
+ we recommend preprocesing in advance.
122
+ For a single downloaded video, there is no need to run preprocessing in advance.
123
+
124
+ From the `slahmr/preproc` directory, run PHALP on all your sequences
125
+ ```
126
+ python launch_phalp.py --type <DATASET_TYPE> --root <DATASET_ROOT> --split <DATASET_SPLIT> --gpus <GPUS>
127
+ ```
128
+ and run DROID-SLAM on all your sequences
129
+ ```
130
+ python launch_slam.py --type <DATASET_TYPE> --root <DATASET_ROOT> --split <DATASET_SPLIT> --gpus <GPUS>
131
+ ```
132
+ You can also update the paths to datasets in `slahmr/preproc/datasets.py` for repeated use.
133
+
134
+ Then, from the `slahmr` directory,
135
+ ```
136
+ python run_opt.py data=<DATA_CFG> run_opt=True run_vis=True
137
+ ```
138
+
139
+ We've provided a helper script `launch.py` for launching many optimization jobs in parallel.
140
+ You can specify job-specific arguments with a job spec file, such as the example files in `job_specs`,
141
+ and batch-specific arguments shared across all jobs as
142
+ ```
143
+ python launch.py --gpus 1 2 -f job_specs/pt_val_shots.txt -s data=posetrack exp_name=posetrack_val
144
+ ```
145
+
146
+ ## Evaluation on 3D datasets
147
+ After launching and completing optimization on either the Egobody or 3DPW datasets,
148
+ you can evaluate the outputs with scripts in the `eval` directory.
149
+ Before running, please update `EGOBODY_ROOT` and `TDPW_ROOT` in `eval/tools.py`.
150
+ Then, run
151
+ ```
152
+ python run_eval.py -d <DSET_TYPE> -i <RES_ROOT> -f <JOB_FILE>
153
+ ```
154
+ where `<JOB_FILE>` is the same job file used to launch all optimization runs.
155
+
156
+
157
+ ## BibTeX
158
+
159
+ If you use our code in your research, please cite the following paper:
160
+ ```
161
+ @inproceedings{ye2023slahmr,
162
+ title={Decoupling Human and Camera Motion from Videos in the Wild},
163
+ author={Ye, Vickie and Pavlakos, Georgios and Malik, Jitendra and Kanazawa, Angjoo},
164
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
165
+ month={June},
166
+ year={2023}
167
+ }
slahmr/download_models.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #!/usr/bin/bash
2
+ # download models
3
+ gdown https://drive.google.com/uc?id=1GXAd-45GzGYNENKgQxFQ4PHrBp8wDRlW
4
+ unzip -q slahmr_dependencies.zip
5
+ rm slahmr_dependencies.zip
slahmr/env.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: slahmr
2
+ channels:
3
+ - conda-forge
4
+ - pytorch
5
+ - nvidia
6
+ - rusty1s
7
+ dependencies:
8
+ - python=3.9
9
+ - pytorch
10
+ - pytorch-cuda=11.7
11
+ - torchvision
12
+ - pytorch-scatter
13
+ - suitesparse
14
+ - pip
15
+ - pip:
16
+ - git+https://github.com/facebookresearch/detectron2.git
17
+ - git+https://github.com/brjathu/pytube.git
18
+ - git+https://github.com/nghorbani/configer
19
+ - setuptools==59.5.0
20
+ - torchgeometry==0.1.2
21
+ - tensorboard
22
+ - smplx
23
+ - pyrender
24
+ - open3d
25
+ - imageio-ffmpeg
26
+ - matplotlib
27
+ - opencv-python
28
+ - scipy
29
+ - scikit-image
30
+ - scikit-learn==0.22
31
+ - joblib
32
+ - cython
33
+ - tqdm
34
+ - hydra-core
35
+ - pyyaml
36
+ - chumpy
37
+ - gdown
38
+ - dill
39
+ - motmetrics
40
+ - scenedetect[opencv]
41
+ - einops
42
+ - mmcv==1.3.9
43
+ - timm==0.4.9
44
+ - xtcocotools==1.10
45
+ - pandas==1.4.0
slahmr/env_build.yaml ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: slahmr2
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - bzip2=1.0.8=h7b6447c_0
8
+ - ca-certificates=2023.05.30=h06a4308_0
9
+ - ld_impl_linux-64=2.38=h1181459_1
10
+ - libffi=3.4.4=h6a678d5_0
11
+ - libgcc-ng=11.2.0=h1234567_1
12
+ - libgomp=11.2.0=h1234567_1
13
+ - libstdcxx-ng=11.2.0=h1234567_1
14
+ - libuuid=1.41.5=h5eee18b_0
15
+ - ncurses=6.4=h6a678d5_0
16
+ - openssl=3.0.9=h7f8727e_0
17
+ - python=3.10.12=h955ad1f_0
18
+ - readline=8.2=h5eee18b_0
19
+ - sqlite=3.41.2=h5eee18b_0
20
+ - tk=8.6.12=h1ccaba5_0
21
+ - tzdata=2023c=h04d1e81_0
22
+ - xz=5.4.2=h5eee18b_0
23
+ - zlib=1.2.13=h5eee18b_0
24
+ - pip:
25
+ - addict==2.4.0
26
+ - ansi2html==1.8.0
27
+ - appdirs==1.4.4
28
+ - asttokens==2.2.1
29
+ - attrs==23.1.0
30
+ - av==10.0.0
31
+ - backcall==0.2.0
32
+ - beautifulsoup4==4.12.2
33
+ - certifi==2022.12.7
34
+ - charset-normalizer==2.1.1
35
+ - click==8.1.4
36
+ - comm==0.1.3
37
+ - configargparse==1.5.5
38
+ - configer==1.4.1
39
+ - configparser==5.3.0
40
+ - contourpy==1.1.0
41
+ - cycler==0.11.0
42
+ - cython==0.29.36
43
+ - dash==2.11.1
44
+ - dash-core-components==2.0.0
45
+ - dash-html-components==2.0.0
46
+ - dash-table==5.0.0
47
+ - debugpy==1.6.7
48
+ - decorator==5.1.1
49
+ - droid-backends==0.0.0
50
+ - executing==1.2.0
51
+ - fastjsonschema==2.17.1
52
+ - flask==2.2.5
53
+ - fonttools==4.40.0
54
+ - gdown==4.7.1
55
+ - hmr2==0.0.0
56
+ - idna==3.4
57
+ - imageio-ffmpeg==0.4.8
58
+ - importlib-metadata==6.7.0
59
+ - ipdb==0.13.13
60
+ - ipykernel==6.24.0
61
+ - ipython==8.14.0
62
+ - ipywidgets==8.0.7
63
+ - itsdangerous==2.1.2
64
+ - jedi==0.18.2
65
+ - jinja2==3.1.2
66
+ - json-tricks==3.17.1
67
+ - jsonschema==4.18.0
68
+ - jsonschema-specifications==2023.6.1
69
+ - jupyter-client==8.3.0
70
+ - jupyter-core==5.3.1
71
+ - jupyterlab-widgets==3.0.8
72
+ - kiwisolver==1.4.4
73
+ - lietorch==0.2
74
+ - matplotlib==3.7.2
75
+ - matplotlib-inline==0.1.6
76
+ - mmcv==1.3.9
77
+ - munkres==1.1.4
78
+ - nbformat==5.7.0
79
+ - nest-asyncio==1.5.6
80
+ - oauthlib==3.2.2
81
+ - open3d==0.17.0
82
+ - pandas==1.4.0
83
+ - parso==0.8.3
84
+ - pexpect==4.8.0
85
+ - phalp==0.1.3
86
+ - pickleshare==0.7.5
87
+ - pillow==9.3.0
88
+ - pip==23.1.2
89
+ - plotly==5.15.0
90
+ - prompt-toolkit==3.0.39
91
+ - psutil==5.9.5
92
+ - ptyprocess==0.7.0
93
+ - pure-eval==0.2.2
94
+ - pyasn1==0.5.0
95
+ - pyasn1-modules==0.3.0
96
+ - pyparsing==3.0.9
97
+ - pyquaternion==0.9.9
98
+ - pysocks==1.7.1
99
+ - pytz==2023.3
100
+ - pyyaml==6.0
101
+ - pyzmq==25.1.0
102
+ - referencing==0.29.1
103
+ - requests==2.28.1
104
+ - retrying==1.3.4
105
+ - rpds-py==0.8.8
106
+ - scikit-learn==1.3.0
107
+ - scipy==1.11.1
108
+ - setuptools==59.5.0
109
+ - six==1.16.0
110
+ - soupsieve==2.4.1
111
+ - stack-data==0.6.2
112
+ - tenacity==8.2.2
113
+ - timm==0.4.9
114
+ - torch==1.13.0+cu117
115
+ - torch-scatter==2.1.1+pt113cu117
116
+ - torchgeometry==0.1.2
117
+ - torchvision==0.14.0+cu117
118
+ - tornado==6.3.2
119
+ - traitlets==5.9.0
120
+ - urllib3==1.26.13
121
+ - wcwidth==0.2.6
122
+ - werkzeug==2.2.3
123
+ - wheel==0.38.4
124
+ - widgetsnbextension==4.0.8
125
+ - xtcocotools==1.13
126
+ - yapf==0.40.1
127
+ - zipp==3.15.0
slahmr/install.sh ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ set -e
3
+
4
+ export CONDA_ENV_NAME=slahmr
5
+
6
+ conda create -n $CONDA_ENV_NAME python=3.10 -y
7
+
8
+ conda activate $CONDA_ENV_NAME
9
+
10
+ # install pytorch using pip, update with appropriate cuda drivers if necessary
11
+ pip install torch==1.13.0 torchvision==0.14.0 --index-url https://download.pytorch.org/whl/cu117
12
+ # uncomment if pip installation isn't working
13
+ # conda install pytorch=1.13.0 torchvision=0.14.0 pytorch-cuda=11.7 -c pytorch -c nvidia -y
14
+
15
+ # install pytorch scatter using pip, update with appropriate cuda drivers if necessary
16
+ pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cu117.html
17
+ # uncomment if pip installation isn't working
18
+ # conda install pytorch-scatter -c pyg -y
19
+
20
+ # install PHALP
21
+ pip install phalp[all]@git+https://github.com/brjathu/PHALP.git
22
+
23
+ # install remaining requirements
24
+ pip install -r requirements.txt
25
+
26
+ # install source
27
+ pip install -e .
28
+
29
+ # install ViTPose
30
+ pip install -v -e third-party/ViTPose
31
+
32
+ # install DROID-SLAM
33
+ cd third-party/DROID-SLAM
34
+ python setup.py install
35
+ cd ../..
slahmr/requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/nghorbani/configer
2
+ setuptools==59.5.0
3
+ torchgeometry==0.1.2
4
+ tensorboard
5
+ numpy==1.23
6
+ smplx
7
+ pyrender
8
+ open3d
9
+ imageio-ffmpeg
10
+ matplotlib
11
+ opencv-python
12
+ scipy
13
+ scikit-image
14
+ joblib
15
+ cython
16
+ tqdm
17
+ hydra-core
18
+ pyyaml
19
+ chumpy
20
+ gdown
21
+ dill
22
+ motmetrics
23
+ einops
24
+ mmcv==1.3.9
25
+ timm==0.4.9
26
+ xtcocotools
27
+ pandas==1.4.0
slahmr/setup.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import find_packages, setup
2
+
3
+ setup(
4
+ name="slahmr",
5
+ packages=find_packages(
6
+ where="slahmr",
7
+ ),
8
+ package_dir={"": "slahmr"},
9
+ )
slahmr/slahmr.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e59655cc88dc71f3a9ac20425310bf996da3b6f15619497d17c67e7b60a671a
3
+ size 5662244525
slahmr/slahmr/__init__.py ADDED
File without changes
slahmr/slahmr/body_model/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .body_model import *
2
+ from .specs import *
3
+ from .utils import *
slahmr/slahmr/body_model/body_model.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from smplx import SMPL, SMPLH, SMPLX
7
+ from smplx.vertex_ids import vertex_ids
8
+ from smplx.utils import Struct
9
+
10
+
11
+ class BodyModel(nn.Module):
12
+ """
13
+ Wrapper around SMPLX body model class.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ bm_path,
19
+ num_betas=10,
20
+ batch_size=1,
21
+ num_expressions=10,
22
+ use_vtx_selector=False,
23
+ model_type="smplh",
24
+ kid_template_path=None,
25
+ ):
26
+ super(BodyModel, self).__init__()
27
+ """
28
+ Creates the body model object at the given path.
29
+
30
+ :param bm_path: path to the body model pkl file
31
+ :param num_expressions: only for smplx
32
+ :param model_type: one of [smpl, smplh, smplx]
33
+ :param use_vtx_selector: if true, returns additional vertices as joints that correspond to OpenPose joints
34
+ """
35
+ self.use_vtx_selector = use_vtx_selector
36
+ cur_vertex_ids = None
37
+ if self.use_vtx_selector:
38
+ cur_vertex_ids = vertex_ids[model_type]
39
+ data_struct = None
40
+ if ".npz" in bm_path:
41
+ # smplx does not support .npz by default, so have to load in manually
42
+ smpl_dict = np.load(bm_path, encoding="latin1")
43
+ data_struct = Struct(**smpl_dict)
44
+ # print(smpl_dict.files)
45
+ if model_type == "smplh":
46
+ data_struct.hands_componentsl = np.zeros((0))
47
+ data_struct.hands_componentsr = np.zeros((0))
48
+ data_struct.hands_meanl = np.zeros((15 * 3))
49
+ data_struct.hands_meanr = np.zeros((15 * 3))
50
+ V, D, B = data_struct.shapedirs.shape
51
+ data_struct.shapedirs = np.concatenate(
52
+ [data_struct.shapedirs, np.zeros((V, D, SMPL.SHAPE_SPACE_DIM - B))],
53
+ axis=-1,
54
+ ) # super hacky way to let smplh use 16-size beta
55
+ kwargs = {
56
+ "model_type": model_type,
57
+ "data_struct": data_struct,
58
+ "num_betas": num_betas,
59
+ "batch_size": batch_size,
60
+ "num_expression_coeffs": num_expressions,
61
+ "vertex_ids": cur_vertex_ids,
62
+ "use_pca": False,
63
+ "flat_hand_mean": False,
64
+ }
65
+ if kid_template_path is not None:
66
+ kwargs["kid_template_path"] = kid_template_path
67
+ kwargs["age"] = "kid"
68
+
69
+ assert model_type in ["smpl", "smplh", "smplx"]
70
+ if model_type == "smpl":
71
+ self.bm = SMPL(bm_path, **kwargs)
72
+ self.num_joints = SMPL.NUM_JOINTS
73
+ elif model_type == "smplh":
74
+ self.bm = SMPLH(bm_path, **kwargs)
75
+ self.num_joints = SMPLH.NUM_JOINTS
76
+ elif model_type == "smplx":
77
+ self.bm = SMPLX(bm_path, **kwargs)
78
+ self.num_joints = SMPLX.NUM_JOINTS
79
+
80
+ self.model_type = model_type
81
+
82
+ def forward(
83
+ self,
84
+ root_orient=None,
85
+ pose_body=None,
86
+ pose_hand=None,
87
+ pose_jaw=None,
88
+ pose_eye=None,
89
+ betas=None,
90
+ trans=None,
91
+ dmpls=None,
92
+ expression=None,
93
+ return_dict=False,
94
+ **kwargs
95
+ ):
96
+ """
97
+ Note dmpls are not supported.
98
+ """
99
+ assert dmpls is None
100
+ out_obj = self.bm(
101
+ betas=betas,
102
+ global_orient=root_orient,
103
+ body_pose=pose_body,
104
+ left_hand_pose=None
105
+ if pose_hand is None
106
+ else pose_hand[:, : (SMPLH.NUM_HAND_JOINTS * 3)],
107
+ right_hand_pose=None
108
+ if pose_hand is None
109
+ else pose_hand[:, (SMPLH.NUM_HAND_JOINTS * 3) :],
110
+ transl=trans,
111
+ expression=expression,
112
+ jaw_pose=pose_jaw,
113
+ leye_pose=None if pose_eye is None else pose_eye[:, :3],
114
+ reye_pose=None if pose_eye is None else pose_eye[:, 3:],
115
+ return_full_pose=True,
116
+ **kwargs
117
+ )
118
+
119
+ out = {
120
+ "v": out_obj.vertices,
121
+ "f": self.bm.faces_tensor,
122
+ "betas": out_obj.betas,
123
+ "Jtr": out_obj.joints,
124
+ "pose_body": out_obj.body_pose,
125
+ "full_pose": out_obj.full_pose,
126
+ }
127
+ if self.model_type in ["smplh", "smplx"]:
128
+ out["pose_hand"] = torch.cat(
129
+ [out_obj.left_hand_pose, out_obj.right_hand_pose], dim=-1
130
+ )
131
+ if self.model_type == "smplx":
132
+ out["pose_jaw"] = out_obj.jaw_pose
133
+ out["pose_eye"] = pose_eye
134
+
135
+ if not self.use_vtx_selector:
136
+ # don't need extra joints
137
+ out["Jtr"] = out["Jtr"][:, : self.num_joints + 1] # add one for the root
138
+
139
+ if not return_dict:
140
+ out = Struct(**out)
141
+
142
+ return out
slahmr/slahmr/body_model/specs.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ SMPL_JOINTS = {
4
+ "hips": 0,
5
+ "leftUpLeg": 1,
6
+ "rightUpLeg": 2,
7
+ "spine": 3,
8
+ "leftLeg": 4,
9
+ "rightLeg": 5,
10
+ "spine1": 6,
11
+ "leftFoot": 7,
12
+ "rightFoot": 8,
13
+ "spine2": 9,
14
+ "leftToeBase": 10,
15
+ "rightToeBase": 11,
16
+ "neck": 12,
17
+ "leftShoulder": 13,
18
+ "rightShoulder": 14,
19
+ "head": 15,
20
+ "leftArm": 16,
21
+ "rightArm": 17,
22
+ "leftForeArm": 18,
23
+ "rightForeArm": 19,
24
+ "leftHand": 20,
25
+ "rightHand": 21,
26
+ }
27
+ SMPL_PARENTS = [
28
+ -1,
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 1,
33
+ 2,
34
+ 3,
35
+ 4,
36
+ 5,
37
+ 6,
38
+ 7,
39
+ 8,
40
+ 9,
41
+ 12,
42
+ 12,
43
+ 12,
44
+ 13,
45
+ 14,
46
+ 16,
47
+ 17,
48
+ 18,
49
+ 19,
50
+ ]
51
+
52
+ SMPLH_PATH = "./body_models/smplh"
53
+ SMPLX_PATH = "./body_models/smplx"
54
+ SMPL_PATH = "./body_models/smpl"
55
+ VPOSER_PATH = "./body_models/vposer_v1_0"
56
+
57
+ # chosen virtual mocap markers that are "keypoints" to work with
58
+ KEYPT_VERTS = [
59
+ 4404,
60
+ 920,
61
+ 3076,
62
+ 3169,
63
+ 823,
64
+ 4310,
65
+ 1010,
66
+ 1085,
67
+ 4495,
68
+ 4569,
69
+ 6615,
70
+ 3217,
71
+ 3313,
72
+ 6713,
73
+ 6785,
74
+ 3383,
75
+ 6607,
76
+ 3207,
77
+ 1241,
78
+ 1508,
79
+ 4797,
80
+ 4122,
81
+ 1618,
82
+ 1569,
83
+ 5135,
84
+ 5040,
85
+ 5691,
86
+ 5636,
87
+ 5404,
88
+ 2230,
89
+ 2173,
90
+ 2108,
91
+ 134,
92
+ 3645,
93
+ 6543,
94
+ 3123,
95
+ 3024,
96
+ 4194,
97
+ 1306,
98
+ 182,
99
+ 3694,
100
+ 4294,
101
+ 744,
102
+ ]
103
+
104
+
105
+ """
106
+ Openpose
107
+ """
108
+ OP_NUM_JOINTS = 25
109
+ # OP_IGNORE_JOINTS = [1, 9, 12] # neck and left/right hip
110
+ OP_IGNORE_JOINTS = [1] # neck
111
+ OP_EDGE_LIST = [
112
+ [1, 8],
113
+ [1, 2],
114
+ [1, 5],
115
+ [2, 3],
116
+ [3, 4],
117
+ [5, 6],
118
+ [6, 7],
119
+ [8, 9],
120
+ [9, 10],
121
+ [10, 11],
122
+ [8, 12],
123
+ [12, 13],
124
+ [13, 14],
125
+ [1, 0],
126
+ [0, 15],
127
+ [15, 17],
128
+ [0, 16],
129
+ [16, 18],
130
+ [14, 19],
131
+ [19, 20],
132
+ [14, 21],
133
+ [11, 22],
134
+ [22, 23],
135
+ [11, 24],
136
+ ]
137
+ # indices to map an openpose detection to its flipped version
138
+ OP_FLIP_MAP = [
139
+ 0,
140
+ 1,
141
+ 5,
142
+ 6,
143
+ 7,
144
+ 2,
145
+ 3,
146
+ 4,
147
+ 8,
148
+ 12,
149
+ 13,
150
+ 14,
151
+ 9,
152
+ 10,
153
+ 11,
154
+ 16,
155
+ 15,
156
+ 18,
157
+ 17,
158
+ 22,
159
+ 23,
160
+ 24,
161
+ 19,
162
+ 20,
163
+ 21,
164
+ ]
165
+
166
+
167
#
# From https://github.com/vchoutas/smplify-x/blob/master/smplifyx/utils.py
# Please see license for usage restrictions.
#
def smpl_to_openpose(
    model_type="smplx",
    use_hands=True,
    use_face=True,
    use_face_contour=False,
    openpose_format="coco25",
):
    """Returns the indices of the permutation that maps SMPL to OpenPose

    Parameters
    ----------
    model_type: str, optional
        The type of SMPL-like model that is used. The default mapping
        returned is for the SMPLX model
    use_hands: bool, optional
        Flag for adding to the returned permutation the mapping for the
        hand keypoints. Defaults to True
    use_face: bool, optional
        Flag for adding to the returned permutation the mapping for the
        face keypoints. Defaults to True
    use_face_contour: bool, optional
        Flag for appending the facial contour keypoints. Defaults to False
    openpose_format: bool, optional
        The output format of OpenPose. For now only COCO-25 and COCO-19 is
        supported. Defaults to 'coco25'

    """
    if openpose_format.lower() == "coco25":
        if model_type == "smpl":
            # SMPL has no hand/face joints; the single body mapping suffices
            return np.array(
                [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7,
                 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
                dtype=np.int32,
            )
        elif model_type == "smplh":
            parts = [
                np.array(
                    [52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7,
                     53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
                    dtype=np.int32,
                )
            ]
            if use_hands:
                parts.append(
                    np.array(
                        [20, 34, 35, 36, 63, 22, 23, 24, 64, 25, 26, 27,
                         65, 31, 32, 33, 66, 28, 29, 30, 67],
                        dtype=np.int32,
                    )
                )
                parts.append(
                    np.array(
                        [21, 49, 50, 51, 68, 37, 38, 39, 69, 40, 41, 42,
                         70, 46, 47, 48, 71, 43, 44, 45, 72],
                        dtype=np.int32,
                    )
                )
            return np.concatenate(parts)
        # SMPLX
        elif model_type == "smplx":
            parts = [
                np.array(
                    [55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7,
                     56, 57, 58, 59, 60, 61, 62, 63, 64, 65],
                    dtype=np.int32,
                )
            ]
            if use_hands:
                parts.append(
                    np.array(
                        [20, 37, 38, 39, 66, 25, 26, 27, 67, 28, 29, 30,
                         68, 34, 35, 36, 69, 31, 32, 33, 70],
                        dtype=np.int32,
                    )
                )
                parts.append(
                    np.array(
                        [21, 52, 53, 54, 71, 40, 41, 42, 72, 43, 44, 45,
                         73, 49, 50, 51, 74, 46, 47, 48, 75],
                        dtype=np.int32,
                    )
                )
            if use_face:
                # 51 face keypoints, plus 17 contour points when requested
                parts.append(
                    np.arange(76, 127 + 17 * use_face_contour, dtype=np.int32)
                )
            return np.concatenate(parts)
        else:
            raise ValueError("Unknown model type: {}".format(model_type))
    elif openpose_format == "coco19":
        if model_type == "smpl":
            return np.array(
                [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28],
                dtype=np.int32,
            )
        elif model_type == "smplh":
            parts = [
                np.array(
                    [52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 53, 54, 55, 56],
                    dtype=np.int32,
                )
            ]
            if use_hands:
                parts.append(
                    np.array(
                        [20, 34, 35, 36, 57, 22, 23, 24, 58, 25, 26, 27,
                         59, 31, 32, 33, 60, 28, 29, 30, 61],
                        dtype=np.int32,
                    )
                )
                parts.append(
                    np.array(
                        [21, 49, 50, 51, 62, 37, 38, 39, 63, 40, 41, 42,
                         64, 46, 47, 48, 65, 43, 44, 45, 66],
                        dtype=np.int32,
                    )
                )
            return np.concatenate(parts)
        # SMPLX
        elif model_type == "smplx":
            parts = [
                np.array(
                    [55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 56, 57, 58, 59],
                    dtype=np.int32,
                )
            ]
            if use_hands:
                parts.append(
                    np.array(
                        [20, 37, 38, 39, 60, 25, 26, 27, 61, 28, 29, 30,
                         62, 34, 35, 36, 63, 31, 32, 33, 64],
                        dtype=np.int32,
                    )
                )
                parts.append(
                    np.array(
                        [21, 52, 53, 54, 65, 40, 41, 42, 66, 43, 44, 45,
                         67, 49, 50, 51, 68, 46, 47, 48, 69],
                        dtype=np.int32,
                    )
                )
            if use_face:
                parts.append(
                    np.arange(70, 70 + 51 + 17 * use_face_contour, dtype=np.int32)
                )
            return np.concatenate(parts)
        else:
            raise ValueError("Unknown model type: {}".format(model_type))
    else:
        raise ValueError("Unknown joint format: {}".format(openpose_format))
slahmr/slahmr/body_model/utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .specs import SMPL_JOINTS
3
+
4
+
5
def run_smpl(body_model, trans, root_orient, body_pose, betas=None):
    """
    Forward pass of the SMPL body model.

    The wrapped body model has a fixed batch size (body_model.bm.batch_size),
    so the inputs are expanded/zero-padded along time to fill it and the
    outputs are cropped back to the requested T frames.

    trans : B x T x 3 root translations
    root_orient : B x T x 3 root orientations (axis-angle)
    body_pose : B x T x J*3 body joint rotations (axis-angle)
    betas : (optional) B x D shape coefficients; zeros when omitted

    Returns a dict with "joints" (B, T, J, 3), "vertices" (B, T, V, 3),
    and "faces" (template mesh faces).
    """
    B, T, _ = trans.shape
    bm_batch_size = body_model.bm.batch_size
    # the model's fixed batch must factor into B sequences of seq_len frames
    assert bm_batch_size % B == 0
    seq_len = bm_batch_size // B
    bm_num_betas = body_model.bm.num_betas
    J_BODY = len(SMPL_JOINTS) - 1  # all joints except root
    if T == 1:
        # must expand to use with body model
        trans = trans.expand(B, seq_len, 3)
        root_orient = root_orient.expand(B, seq_len, 3)
        body_pose = body_pose.expand(B, seq_len, J_BODY * 3)
    elif T != seq_len:
        # pad the temporal dimension with zeros up to seq_len frames
        trans, root_orient, body_pose = zero_pad_tensors(
            [trans, root_orient, body_pose], seq_len - T
        )
    if betas is None:
        betas = torch.zeros(B, bm_num_betas, device=trans.device)
    # broadcast the per-sequence betas across all frames
    betas = betas.reshape((B, 1, bm_num_betas)).expand((B, seq_len, bm_num_betas))
    smpl_body = body_model(
        pose_body=body_pose.reshape((B * seq_len, -1)),
        pose_hand=None,
        betas=betas.reshape((B * seq_len, -1)),
        root_orient=root_orient.reshape((B * seq_len, -1)),
        trans=trans.reshape((B * seq_len, -1)),
    )
    # crop the padded frames back to the requested length T
    return {
        "joints": smpl_body.Jtr.reshape(B, seq_len, -1, 3)[:, :T],
        "vertices": smpl_body.v.reshape(B, seq_len, -1, 3)[:, :T],
        "faces": smpl_body.f,
    }
45
+
46
+
47
def zero_pad_tensors(pad_list, pad_size):
    """
    Zero-pad tensors along the temporal (second) dimension.

    Assumes tensors in pad_list are B x T x D and pads the temporal dimension.

    :param pad_list: list of tensors, each of shape B x T_i x D_i
    :param pad_size: number of zero frames to append along dim 1
    :returns: new list of tensors of shape B x (T_i + pad_size) x D_i
    """
    B = pad_list[0].size(0)
    new_pad_list = []
    # NOTE: the index from enumerate was unused; iterate tensors directly
    for pad_tensor in pad_list:
        # .to(pad_tensor) matches both dtype and device of the source tensor
        padding = torch.zeros((B, pad_size, pad_tensor.size(2))).to(pad_tensor)
        new_pad_list.append(torch.cat([pad_tensor, padding], dim=1))
    return new_pad_list
slahmr/slahmr/confs/config.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - data: posetrack
3
+ - optim
4
+ - _self_
5
+
6
+ model:
7
+ floor_type: "shared"
8
+ est_floor: False
9
+ use_init: True
10
+ opt_cams: False
11
+ opt_scale: True
12
+ async_tracks: True
13
+
14
+ overwrite: False
15
+ run_opt: False
16
+ run_vis: False
17
+ vis:
18
+ phases:
19
+ - motion_chunks
20
+ - input
21
+ render_views:
22
+ - src_cam
23
+ - above
24
+ - side
25
+ make_grid: True
26
+ overwrite: False
27
+
28
+ paths:
29
+ smpl: _DATA/body_models/smplh/neutral/model.npz
30
+ smpl_kid: _DATA/body_models/smpl_kid_template.npy
31
+ vposer: _DATA/body_models/vposer_v1_0
32
+ init_motion_prior: _DATA/humor_ckpts/init_state_prior_gmm
33
+ humor: _DATA/humor_ckpts/humor/best_model.pth
34
+
35
+ humor:
36
+ in_rot_rep: "mat"
37
+ out_rot_rep: "aa"
38
+ latent_size: 48
39
+ model_data_config: "smpl+joints+contacts"
40
+ steps_in: 1
41
+
42
+ fps: 30
43
+ log_root: ../outputs/logs
44
+ log_dir: ${log_root}/${data.type}-${data.split}
45
+ exp_name: ${now:%Y-%m-%d}
46
+
47
+ hydra:
48
+ job:
49
+ chdir: True
50
+ run:
51
+ dir: ${log_dir}/${exp_name}/${data.name}
slahmr/slahmr/confs/data/3dpw.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ type: 3dpw
2
+ split: test
3
+ seq: downtown_arguing_00
4
+ root: /path/to/3DPW
5
+ use_cams: True
6
+ split_cameras: False
7
+ camera_name: cameras_intrins_split
8
+ shot_idx: 0
9
+ start_idx: 0
10
+ end_idx: 100
11
+ track_ids: "longest-2"
12
+ sources:
13
+ images: ${data.root}/imageFiles/${data.seq}
14
+ cameras: ${data.root}/slahmr/${data.camera_name}/${data.seq}/${data.start_idx}-${data.end_idx}
15
+ intrins: ${data.root}/slahmr/cameras_gt/${data.seq}/intrinsics.txt
16
+ tracks: ${data.root}/slahmr/track_preds/${data.seq}
17
+ shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
18
+ name: ${data.seq}-${data.track_ids}-${data.start_idx}-${data.end_idx}
slahmr/slahmr/confs/data/3dpw_gt.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ type: 3dpw_gt
2
+ split: test
3
+ seq: downtown_runForBus_00
4
+ root: /path/to/3DPW
5
+ use_cams: True
6
+ split_cameras: False
7
+ camera_name: cameras_intrins_split
8
+ shot_idx: 0
9
+ start_idx: 0
10
+ end_idx: 100
11
+ track_ids: "longest-2"
12
+ sources:
13
+ images: ${data.root}/imageFiles/${data.seq}
14
+ cameras: ${data.root}/slahmr/${data.camera_name}/${data.seq}/${data.start_idx}-${data.end_idx}
15
+ intrins: ${data.root}/slahmr/cameras_gt/${data.seq}/intrinsics.txt
16
+ tracks: ${data.root}/slahmr/track_gt/${data.seq}
17
+ shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
18
+ name: ${data.seq}-${data.track_ids}-${data.start_idx}-${data.end_idx}
slahmr/slahmr/confs/data/custom.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ type: custom
2
+ split: val
3
+ video: ""
4
+ seq: ""
5
+ root: /path/to/custom
6
+ use_cams: True
7
+ track_ids: "all"
8
+ shot_idx: 0
9
+ start_idx: 0
10
+ end_idx: 200
11
+ split_cameras: True
12
+ name: ${data.seq}-${data.track_ids}-shot-${data.shot_idx}
13
+ sources:
14
+ images: ${data.root}/images/${data.seq}
15
+ cameras: ${data.root}/slahmr/cameras/${data.seq}/shot-${data.shot_idx}
16
+ track: ${data.root}/slahmr/track_preds/${data.seq}
17
+ shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
slahmr/slahmr/confs/data/davis.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ type: davis
2
+ split: all
3
+ seq: parkour
4
+ root: /path/to/DAVIS
5
+ use_cams: True
6
+ track_ids: "all"
7
+ shot_idx: 0
8
+ start_idx: 0
9
+ end_idx: -1
10
+ split_cameras: True
11
+ name: ${data.seq}-${data.track_ids}-shot-${data.shot_idx}
12
+ sources:
13
+ images: ${data.root}/JPEGImages/Full-Resolution/${data.seq}
14
+ cameras: ${data.root}/slahmr/cameras/${data.seq}/shot-${data.shot_idx}
15
+ tracks: ${data.root}/slahmr/track_preds/${data.seq}
16
+ shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
slahmr/slahmr/confs/data/egobody.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ type: egobody
2
+ split: val
3
+ seq: recording_20210921_S11_S10_01
4
+ root: /path/to/egobody
5
+ use_cams: True
6
+ camera_name: cameras_intrins_split
7
+ shot_idx: 0
8
+ start_idx: 0
9
+ end_idx: 100
10
+ split_cameras: False
11
+ track_ids: "all"
12
+ sources:
13
+ images: ${data.root}/egocentric_color/${data.seq}/**/PV
14
+ cameras: ${data.root}/slahmr/${data.camera_name}/${data.seq}/${data.start_idx}-${data.end_idx}
15
+ intrins: ${data.root}/slahmr/cameras_gt/${data.seq}/intrinsics.txt
16
+ tracks: ${data.root}/slahmr/track_preds/${data.seq}
17
+ shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
18
+ name: ${data.seq}-${data.track_ids}-${data.start_idx}-${data.end_idx}
slahmr/slahmr/confs/data/posetrack.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ type: posetrack
2
+ split: val
3
+ seq: 014286_mpii_train
4
+ root: /path/to/posetrack
5
+ use_cams: True
6
+ track_ids: "all"
7
+ shot_idx: 0
8
+ start_idx: 0
9
+ end_idx: -1
10
+ split_cameras: True
11
+ name: ${data.seq}-${data.track_ids}-shot-${data.shot_idx}
12
+ track_name: track_preds
13
+ sources:
14
+ images: ${data.root}/images/${data.split}/${data.seq}
15
+ cameras: ${data.root}/slahmr/${data.split}/cameras/${data.seq}/shot-${data.shot_idx}
16
+ tracks: ${data.root}/slahmr/${data.split}/${data.track_name}/${data.seq}
17
+ shots: ${data.root}/slahmr/${data.split}/shot_idcs/${data.seq}.json
slahmr/slahmr/confs/data/video.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ type: video
2
+ split: val
3
+ root: /path/to/data # put your videos in root/videos/vid.mp4
4
+ video_dir: videos
5
+ seq: basketball
6
+ ext: mp4
7
+ src_path: ${data.root}/${data.video_dir}/${data.seq}.${data.ext}
8
+ frame_opts:
9
+ ext: jpg
10
+ fps: 25
11
+ start_sec: 0
12
+ end_sec: -1
13
+ use_cams: True
14
+ track_ids: "all"
15
+ shot_idx: 0
16
+ start_idx: 0
17
+ end_idx: 180
18
+ split_cameras: True
19
+ name: ${data.seq}-${data.track_ids}-shot-${data.shot_idx}-${data.start_idx}-${data.end_idx}
20
+ sources:
21
+ images: ${data.root}/images/${data.seq}
22
+ cameras: ${data.root}/slahmr/cameras/${data.seq}/shot-${data.shot_idx}
23
+ tracks: ${data.root}/slahmr/track_preds/${data.seq}
24
+ shots: ${data.root}/slahmr/shot_idcs/${data.seq}.json
slahmr/slahmr/confs/init.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - data: posetrack
3
+ - _self_
4
+
5
+ gap: 1
6
+ log_root: outputs
7
+ save_per_frame: False
8
+ print_err: False
9
+ stride: 48
10
+
11
+ hydra:
12
+ run:
13
+ dir: ${log_root}/init/${data.type}-${data.seq}-${data.depth_dir}-${data.fov}fov-gap${gap}
slahmr/slahmr/confs/optim.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ optim:
2
+ options:
3
+ robust_loss_type: "bisquare"
4
+ robust_tuning_const: 4.6851
5
+ joints2d_sigma: 100.0
6
+ lr: 1.0
7
+ lbfgs_max_iter: 20
8
+ save_every: 20
9
+ vis_every: -1
10
+ max_chunk_steps: 20
11
+ save_meshes: False
12
+
13
+ root:
14
+ num_iters: 30
15
+
16
+ smpl:
17
+ num_iters: 0
18
+
19
+ smooth:
20
+ opt_scale: False
21
+ num_iters: 60
22
+
23
+ motion_chunks:
24
+ chunk_size: 10
25
+ init_steps: 20
26
+ chunk_steps: 20
27
+ opt_cams: True
28
+
29
+ loss_weights:
30
+ joints2d: [0.001, 0.001, 0.001]
31
+ bg2d: [0.0, 0.000, 0.000]
32
+ cam_R_smooth : [0.0, 0.0, 0.0]
33
+ cam_t_smooth : [0.0, 0.0, 0.0]
34
+ # bg2d: [0.0, 0.0001, 0.0001]
35
+ # cam_R_smooth : [0.0, 1000.0, 1000.0]
36
+ # cam_t_smooth : [0.0, 1000.0, 1000.0]
37
+ joints3d: [0.0, 0.0, 0.0]
38
+ joints3d_smooth: [1.0, 10.0, 0.0]
39
+ joints3d_rollout: [0.0, 0.0, 0.0]
40
+ verts3d: [0.0, 0.0, 0.0]
41
+ points3d: [0.0, 0.0, 0.0]
42
+ pose_prior: [0.04, 0.04, 0.04]
43
+ shape_prior: [0.05, 0.05, 0.05]
44
+ motion_prior: [0.0, 0.0, 0.075]
45
+ init_motion_prior: [0.0, 0.0, 0.075]
46
+ joint_consistency: [0.0, 0.0, 100.0]
47
+ bone_length: [0.0, 0.0, 2000.0]
48
+ contact_vel: [0.0, 0.0, 100.0]
49
+ contact_height: [0.0, 0.0, 10.0]
50
+ floor_reg: [0.0, 0.0, 0.0]
51
+ # floor_reg: [0.0, 0.0, 0.167]
slahmr/slahmr/data/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .dataset import *
2
+ from . import tools
slahmr/slahmr/data/dataset.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import typing
4
+
5
+ import imageio
6
+ import numpy as np
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import Dataset
12
+
13
+ from body_model import OP_NUM_JOINTS, SMPL_JOINTS
14
+ from util.logger import Logger
15
+ from geometry.camera import invert_camera
16
+
17
+ from .tools import read_keypoints, read_mask_path, load_smpl_preds
18
+ from .vidproc import preprocess_cameras, preprocess_frames, preprocess_tracks
19
+
20
+
21
"""
Define data-related constants
"""
# Ground plane parameters (normal xyz, offset); used as the default
# floor_plane observation for every track.
DEFAULT_GROUND = np.array([0.0, -1.0, 0.0, -0.5])

# XXX: TEMPORARY CONSTANTS
SHOT_PAD = 0  # frames to drop on each side of a shot boundary
MIN_SEQ_LEN = 20  # minimum number of frames for an optimizable shot
MAX_NUM_TRACKS = 12  # cap on the number of tracks kept per shot
MIN_TRACK_LEN = 20  # minimum detected frames for a track to be kept
MIN_KEYP_CONF = 0.4  # keypoints below this confidence are zeroed out
32
+
33
+
34
def get_dataset_from_cfg(cfg):
    """
    Build a MultiPeopleDataset from a hydra config.

    Expands glob patterns in cfg.data.sources, runs any missing video
    preprocessing (frames/tracks/cameras for "video" data), and constructs
    the dataset over the configured frame interval and track selection.

    :param cfg: config with a `data` namespace (see confs/data/*.yaml)
    """
    args = cfg.data
    if not args.use_cams:
        # empty camera source makes the dataset fall back to static cameras
        args.sources.cameras = ""

    args.sources = expand_source_paths(args.sources)
    print("DATA SOURCES", args.sources)
    check_data_sources(args)
    return MultiPeopleDataset(
        args.sources,
        args.seq,
        tid_spec=args.track_ids,
        shot_idx=args.shot_idx,
        start_idx=int(args.start_idx),
        end_idx=int(args.end_idx),
        split_cameras=args.get("split_cameras", True),
    )
51
+
52
+
53
def expand_source_paths(data_sources):
    """Resolve every data-source glob pattern to a concrete path."""
    resolved = {}
    for key, pattern in data_sources.items():
        resolved[key] = get_data_source(pattern)
    return resolved
55
+
56
+
57
def get_data_source(source):
    """
    Resolve a path or glob pattern to a single existing path.

    Returns the pattern unchanged when nothing matches (callers rely on
    default values downstream); raises ValueError when it is ambiguous.
    """
    matches = glob.glob(source)
    if not matches:
        print(f"{source} does not exist")
        # return anyway for default values
        return source
    if len(matches) > 1:
        raise ValueError(f"{source} is not unique")
    return matches[0]
65
+
66
+
67
def check_data_sources(args):
    """
    Run any missing preprocessing for a raw video data source.

    For "video" data: extracts frames from the source video, runs tracking
    on the frames, and estimates cameras; each step is skipped when its
    output directory is already populated.
    """
    if args.type == "video":
        preprocess_frames(args.sources.images, args.src_path, **args.frame_opts)
        preprocess_tracks(args.sources.images, args.sources.tracks, args.sources.shots)
        preprocess_cameras(args, overwrite=args.get("overwrite_cams", False))
72
+
73
+
74
class MultiPeopleDataset(Dataset):
    """
    Dataset of per-person observation tracks for one shot of a video.

    Each item is one person track: 2D keypoints, per-frame SMPL
    initializations, visibility mask, and a floor plane. Camera data is
    shared across tracks and accessed via get_camera_data().
    """

    def __init__(
        self,
        data_sources: typing.Dict,
        seq_name,
        tid_spec="all",
        shot_idx=0,
        start_idx=0,
        end_idx=-1,
        pad_shot=False,
        split_cameras=True,
    ):
        """
        :param data_sources: dict of named source paths
            (images, tracks, shots, cameras, ...)
        :param seq_name: name of the sequence
        :param tid_spec: track selection: "all", "longest-<k>", or a
            dash-separated list of integer track ids
        :param shot_idx: index of the shot to select
        :param start_idx: first frame (within the shot) to use
        :param end_idx: end frame (exclusive); <= 0 means use all frames
        :param pad_shot: whether to drop frames around shot boundaries
        :param split_cameras: whether camera files cover only this interval
        """
        self.seq_name = seq_name
        self.data_sources = data_sources
        self.split_cameras = split_cameras

        # select only images in the desired shot
        img_files, _ = get_shot_img_files(
            self.data_sources["shots"], shot_idx, pad_shot
        )
        end_idx = end_idx if end_idx > 0 else len(img_files)
        self.data_start, self.data_end = start_idx, end_idx
        img_files = img_files[start_idx:end_idx]
        self.img_names = [get_name(f) for f in img_files]
        self.num_imgs = len(self.img_names)

        img_dir = self.data_sources["images"]
        assert os.path.isdir(img_dir)
        self.img_paths = [os.path.join(img_dir, f) for f in img_files]
        # read one image to get the working resolution (width, height)
        img_h, img_w = imageio.imread(self.img_paths[0]).shape[:2]
        self.img_size = img_w, img_h
        print(f"USING TOTAL {self.num_imgs} {img_w}x{img_h} IMGS")

        # find the tracks in the video
        track_root = self.data_sources["tracks"]
        if tid_spec == "all" or tid_spec.startswith("longest"):
            n_tracks = MAX_NUM_TRACKS
            if tid_spec.startswith("longest"):
                n_tracks = int(tid_spec.split("-")[1])
            # get the longest tracks in the selected shot
            track_ids = sorted(os.listdir(track_root))
            track_paths = [
                [f"{track_root}/{tid}/{name}_keypoints.json" for name in self.img_names]
                for tid in track_ids
            ]
            # track length = number of frames with a keypoint file on disk
            track_lens = [
                len(list(filter(os.path.isfile, paths))) for paths in track_paths
            ]
            # keep tracks longer than MIN_TRACK_LEN, longest first
            track_ids = [
                track_ids[i]
                for i in np.argsort(track_lens)[::-1]
                if track_lens[i] > MIN_TRACK_LEN
            ]
            print("TRACK LENGTHS", track_ids, track_lens)
            track_ids = track_ids[:n_tracks]
        else:
            # explicit dash-separated ids, zero-padded to match dir names
            track_ids = [f"{int(tid):03d}" for tid in tid_spec.split("-")]

        print("TRACK IDS", track_ids)

        self.track_ids = track_ids
        self.n_tracks = len(track_ids)
        self.track_dirs = [os.path.join(track_root, tid) for tid in track_ids]

        # keep a list of frame index masks of whether a track is available in a frame
        sidx = np.inf
        eidx = -1
        self.track_vis_masks = []
        for pred_dir in self.track_dirs:
            kp_paths = [f"{pred_dir}/{x}_keypoints.json" for x in self.img_names]
            has_kp = [os.path.isfile(x) for x in kp_paths]

            # keep track of which frames this track is visible in
            vis_mask = np.array(has_kp)
            idcs = np.where(vis_mask)[0]
            if len(idcs) > 0:
                si, ei = min(idcs), max(idcs)
                sidx = min(sidx, si)
                eidx = max(eidx, ei)
            self.track_vis_masks.append(vis_mask)

        # NOTE(review): if no track has any keypoints, sidx stays np.inf —
        # confirm upstream guarantees at least one visible frame
        eidx = max(eidx + 1, 0)
        sidx = min(sidx, eidx)
        print("START", sidx, "END", eidx)
        self.start_idx = sidx
        self.end_idx = eidx
        self.seq_len = eidx - sidx
        # all tracks share the same overall [sidx, eidx) interval
        self.seq_intervals = [(sidx, eidx) for _ in track_ids]

        self.sel_img_paths = self.img_paths[sidx:eidx]
        self.sel_img_names = self.img_names[sidx:eidx]

        # used to cache data
        self.data_dict = {}
        self.cam_data = None

    def __len__(self):
        # one item per person track
        return self.n_tracks

    def load_data(self, interp_input=True):
        """
        Load and cache per-track observations (keypoints, SMPL inits,
        visibility, floor) plus the shared camera data. No-op if cached.

        :param interp_input: interpolate SMPL predictions through occlusions
        """
        if len(self.data_dict) > 0:
            return

        # load camera data
        self.load_camera_data()
        # get data for each track
        data_out = {
            "mask_paths": [],
            "floor_plane": [],
            "joints2d": [],
            "vis_mask": [],
            "track_interval": [],
            "init_body_pose": [],
            "init_root_orient": [],
            "init_trans": [],
        }

        # create batches of sequences
        # each batch is a track for a person
        T = self.seq_len
        sidx, eidx = self.start_idx, self.end_idx
        for i, tid in enumerate(self.track_ids):
            # load mask of visible frames for this track
            vis_mask = self.track_vis_masks[i][sidx:eidx]  # (T)
            vis_idcs = np.where(vis_mask)[0]
            track_s, track_e = min(vis_idcs), max(vis_idcs) + 1
            data_out["track_interval"].append([track_s, track_e])

            # -1 out of scene / 0 occluded / 1 visible
            vis_mask = get_ternary_mask(vis_mask)
            data_out["vis_mask"].append(vis_mask)

            # load 2d keypoints for visible frames
            kp_paths = [
                f"{self.track_dirs[i]}/{x}_keypoints.json" for x in self.sel_img_names
            ]
            # (T, J, 3) (x, y, conf)
            joints2d_data = np.stack(
                [read_keypoints(p) for p in kp_paths], axis=0
            ).astype(np.float32)
            # Discard bad ViTPose detections
            joints2d_data[
                np.repeat(joints2d_data[:, :, [2]] < MIN_KEYP_CONF, 3, axis=2)
            ] = 0
            data_out["joints2d"].append(joints2d_data)

            # load single image smpl predictions
            pred_paths = [
                f"{self.track_dirs[i]}/{x}_smpl.json" for x in self.sel_img_names
            ]
            pose_init, orient_init, trans_init, _ = load_smpl_preds(
                pred_paths, interp=interp_input
            )

            n_joints = len(SMPL_JOINTS) - 1
            data_out["init_body_pose"].append(pose_init[:, :n_joints, :])
            data_out["init_root_orient"].append(orient_init)
            data_out["init_trans"].append(trans_init)

            # plane point = normal * offset
            data_out["floor_plane"].append(DEFAULT_GROUND[:3] * DEFAULT_GROUND[3:])

        self.data_dict = data_out

    def __getitem__(self, idx):
        """Return the observation dict for track `idx` (loads lazily)."""
        if len(self.data_dict) < 1:
            self.load_data()

        obs_data = dict()

        # 2D keypoints
        joint2d_data = self.data_dict["joints2d"][idx]
        obs_data["joints2d"] = torch.Tensor(joint2d_data)

        # single frame predictions
        obs_data["init_body_pose"] = torch.Tensor(self.data_dict["init_body_pose"][idx])
        obs_data["init_root_orient"] = torch.Tensor(
            self.data_dict["init_root_orient"][idx]
        )
        obs_data["init_trans"] = torch.Tensor(self.data_dict["init_trans"][idx])

        # floor plane
        obs_data["floor_plane"] = torch.Tensor(self.data_dict["floor_plane"][idx])

        # the frames the track is visible in
        obs_data["vis_mask"] = torch.Tensor(self.data_dict["vis_mask"][idx])

        # the frames used in this subsequence
        obs_data["seq_interval"] = torch.Tensor(list(self.seq_intervals[idx])).to(
            torch.int
        )
        # the start and end interval of available keypoints
        obs_data["track_interval"] = torch.Tensor(
            self.data_dict["track_interval"][idx]
        ).int()

        obs_data["track_id"] = int(self.track_ids[idx])
        obs_data["seq_name"] = self.seq_name
        return obs_data

    def load_camera_data(self):
        """Load the shared CameraData for the selected frame interval."""
        cam_dir = self.data_sources["cameras"]
        data_interval = 0, -1
        if self.split_cameras:
            # cameras were estimated only on the [data_start, data_end) slice
            data_interval = self.data_start, self.data_end
        track_interval = self.start_idx, self.end_idx
        self.cam_data = CameraData(
            cam_dir, self.seq_len, self.img_size, data_interval, track_interval
        )

    def get_camera_data(self):
        """Return camera tensors as a dict; requires load_camera_data first."""
        if self.cam_data is None:
            raise ValueError
        return self.cam_data.as_dict()
286
+
287
+
288
class CameraData(object):
    """
    Per-sequence camera extrinsics and intrinsics.

    Loads world-to-camera poses from a cameras.npz file when present,
    otherwise falls back to static identity cameras with a heuristic focal.
    """

    def __init__(
        self, cam_dir, seq_len, img_size, data_interval=[0, -1], track_interval=[0, -1]
    ):
        """
        :param cam_dir: directory expected to contain cameras.npz
        :param seq_len: number of frames in the full sequence
        :param img_size: (width, height) of the working images
        :param data_interval: [start, end) slice the cameras were computed on
        :param track_interval: [start, end) of the track within that slice
        NOTE: the list defaults are never mutated here.
        """
        self.img_size = img_size
        self.cam_dir = cam_dir

        # inclusive exclusive
        data_start, data_end = data_interval
        if data_end < 0:
            # negative end wraps around (-1 -> seq_len)
            data_end += seq_len + 1
        data_len = data_end - data_start

        # start and end indices are with respect to the data interval
        sidx, eidx = track_interval
        if eidx < 0:
            eidx += data_len + 1
        # convert to absolute frame indices
        self.sidx, self.eidx = sidx + data_start, eidx + data_start
        self.seq_len = self.eidx - self.sidx

        self.load_data()

    def load_data(self):
        """Populate cam_R, cam_t, intrins, is_static from disk or defaults."""
        # camera info
        sidx, eidx = self.sidx, self.eidx
        img_w, img_h = self.img_size
        fpath = os.path.join(self.cam_dir, "cameras.npz")
        if os.path.isfile(fpath):
            Logger.log(f"Loading cameras from {fpath}...")
            cam_R, cam_t, intrins, width, height = load_cameras_npz(fpath)
            # rescale intrinsics if cameras were estimated at another resolution
            scale = img_w / width
            self.intrins = scale * intrins[sidx:eidx]
            # move first camera to origin
            # R0, t0 = invert_camera(cam_R[sidx], cam_t[sidx])
            # self.cam_R = torch.einsum("ij,...jk->...ik", R0, cam_R[sidx:eidx])
            # self.cam_t = t0 + torch.einsum("ij,...j->...i", R0, cam_t[sidx:eidx])
            # t0 = -cam_t[sidx:eidx].mean(dim=0) + torch.randn(3) * 0.1
            # NOTE(review): t0 is the NEGATED first translation plus noise, so
            # subtracting it below effectively ADDS cam_t[sidx] — confirm that
            # this recentering (rather than cam_t - cam_t[sidx]) is intended
            t0 = -cam_t[sidx:sidx+1] + torch.randn(3) * 0.1
            self.cam_R = cam_R[sidx:eidx]
            self.cam_t = cam_t[sidx:eidx] - t0
            self.is_static = False
        else:
            Logger.log(f"WARNING: {fpath} does not exist, using static cameras...")
            # heuristic focal length from the image dimensions
            default_focal = 0.5 * (img_h + img_w)
            self.intrins = torch.tensor(
                [default_focal, default_focal, img_w / 2, img_h / 2]
            )[None].repeat(self.seq_len, 1)

            self.cam_R = torch.eye(3)[None].repeat(self.seq_len, 1, 1)
            self.cam_t = torch.zeros(self.seq_len, 3)
            self.is_static = True

        Logger.log(f"Images have {img_w}x{img_h}, intrins {self.intrins[0]}")
        print("CAMERA DATA", self.cam_R.shape, self.cam_t.shape, self.intrins[0])

    def world2cam(self):
        """Return world-to-camera rotations (T, 3, 3) and translations (T, 3)."""
        return self.cam_R, self.cam_t

    def cam2world(self):
        """Return the inverse (camera-to-world) rotations and translations."""
        R = self.cam_R.transpose(-1, -2)
        # t' = -R^T t
        t = -torch.einsum("bij,bj->bi", R, self.cam_t)
        return R, t

    def as_dict(self):
        """Package the camera tensors for consumers expecting a plain dict."""
        return {
            "cam_R": self.cam_R,  # (T, 3, 3)
            "cam_t": self.cam_t,  # (T, 3)
            "intrins": self.intrins,  # (T, 4)
            "static": self.is_static,  # bool
        }
358
+
359
+
360
def get_ternary_mask(vis_mask):
    """
    Convert a boolean visibility mask into a ternary float mask.

    -1 = track out of scene (before first / after last detection),
     0 = occluded (inside the track interval but not detected),
     1 = visible.
    """
    mask = torch.as_tensor(vis_mask).float()
    visible = torch.where(mask > 0)[0]
    # the track interval spans first to last visible frame (inclusive)
    first = int(visible.min())
    last = int(visible.max()) + 1
    mask[:first] = -1
    mask[last:] = -1
    return mask
370
+
371
+
372
def get_shot_img_files(shots_path, shot_idx, shot_pad=SHOT_PAD):
    """
    Select the image names belonging to one shot of the video.

    :param shots_path: json file mapping image name -> shot index
    :param shot_idx: shot to select
    :param shot_pad: if > 0, drop this many frames next to shot boundaries
    :returns: (sorted selected image names, their frame indices)
    :raises ValueError: if the padded shot is shorter than MIN_SEQ_LEN
    """
    assert os.path.isfile(shots_path)
    with open(shots_path, "r") as f:
        shots_dict = json.load(f)
    img_names = sorted(shots_dict.keys())
    N = len(img_names)
    shot_mask = np.array([shots_dict[x] == shot_idx for x in img_names])

    idcs = np.where(shot_mask)[0]
    if shot_pad > 0:  # drop the frames before/after shot change
        # only pad the front if the shot does not start the video
        if min(idcs) > 0:
            idcs = idcs[shot_pad:]
        # only pad the back if the shot does not end the video
        if len(idcs) > 0 and max(idcs) < N - 1:
            idcs = idcs[:-shot_pad]
        if len(idcs) < MIN_SEQ_LEN:
            raise ValueError("shot is too short for optimization")

    # NOTE(review): this rebuilt mask is not returned or read — confirm unused
    shot_mask = np.zeros(N, dtype=bool)
    shot_mask[idcs] = 1
    sel_paths = [img_names[i] for i in idcs]
    print(f"FOUND {len(idcs)}/{len(shots_dict)} FRAMES FOR SHOT {shot_idx}")
    return sel_paths, idcs
394
+
395
+
396
def load_cameras_npz(camera_path):
    """
    Load camera extrinsics and intrinsics from a .npz archive.

    Expected keys: "height", "width", "focal", "w2c" (N, 4, 4), and
    optionally "intrins" (N, 4) as (fx, fy, cx, cy).

    :returns: cam_R (N, 3, 3), cam_t (N, 3), intrins (N, 4), width, height
    """
    assert os.path.splitext(camera_path)[-1] == ".npz"

    archive = np.load(camera_path)
    height = int(archive["height"])
    width = int(archive["width"])
    focal = float(archive["focal"])

    w2c = torch.from_numpy(archive["w2c"])  # (N, 4, 4) world-to-camera
    num_cams = len(w2c)
    cam_R = w2c[:, :3, :3]  # (N, 3, 3)
    cam_t = w2c[:, :3, 3]  # (N, 3)

    if "intrins" in archive:
        intrins = torch.from_numpy(archive["intrins"].astype(np.float32))
    else:
        # fall back to a shared pinhole guess centered on the image
        shared = [focal, focal, width / 2, height / 2]
        intrins = torch.tensor(shared)[None].repeat(num_cams, 1)

    print(f"Loaded {num_cams} cameras")
    return cam_R, cam_t, intrins, width, height
418
+
419
+
420
def is_image(x):
    """True for non-hidden .png/.jpg filenames."""
    if x.startswith("."):
        return False
    return x.endswith(".png") or x.endswith(".jpg")


def get_name(x):
    """Basename of a path without its extension."""
    base = os.path.basename(x)
    return os.path.splitext(base)[0]


def split_name(x, suffix):
    """Basename of a path with a trailing suffix removed."""
    return os.path.basename(x).split(suffix)[0]


def get_names_in_dir(d, suffix):
    """Sorted names of files in directory d ending with suffix."""
    matches = glob.glob(f"{d}/*{suffix}")
    return sorted(split_name(m, suffix) for m in matches)


def batch_join(parent, names, suffix=""):
    """Join each name (plus suffix) onto a parent directory."""
    return [os.path.join(parent, f"{n}{suffix}") for n in names]
slahmr/slahmr/data/tools.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import functools
4
+
5
+ import numpy as np
6
+
7
+ from body_model import OP_NUM_JOINTS
8
+ from scipy.interpolate import interp1d
9
+ from scipy.spatial.transform import Rotation, Slerp
10
+
11
+
12
def read_keypoints(keypoint_fn):
    """
    Read OpenPose-format 2D body keypoints of the first person only.

    :param keypoint_fn: path to a *_keypoints.json file
    :returns: (OP_NUM_JOINTS, 3) float array of (x, y, conf); all zeros when
        the file is missing or contains no people
    """
    # BUG FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin float is the documented equivalent (float64)
    empty_kps = np.zeros((OP_NUM_JOINTS, 3), dtype=float)
    if not os.path.isfile(keypoint_fn):
        return empty_kps

    with open(keypoint_fn) as keypoint_file:
        data = json.load(keypoint_file)

    if len(data["people"]) == 0:
        print("WARNING: Found no keypoints in %s! Returning zeros!" % (keypoint_fn))
        return empty_kps

    person_data = data["people"][0]
    body_keypoints = np.array(person_data["pose_keypoints_2d"], dtype=float)
    body_keypoints = body_keypoints.reshape([-1, 3])
    return body_keypoints
31
+
32
+
33
def read_mask_path(path):
    """
    Read the segmentation mask path from a keypoint json file, if recorded.

    :param path: path to a *_keypoints.json file
    :returns: the first person's "mask_path" entry, or None if the file is
        missing, has no people, or no mask path is recorded
    """
    if not os.path.isfile(path):
        return None

    with open(path, "r") as f:
        # BUG FIX: json.load was previously called with the filename string
        # instead of the open file handle, raising AttributeError at runtime
        data = json.load(f)

    people = data.get("people", [])
    if not people:
        # previously indexed people[0] unconditionally -> IndexError
        return None
    return people[0].get("mask_path")
46
+
47
+
48
def read_smpl_preds(pred_path, num_betas=10):
    """
    Read one frame's SMPL prediction (exported from PHALP outputs).

    Missing files and missing fields fall back to zeros.

    :param pred_path: path to a *_smpl.json file (may not exist)
    :param num_betas: number of shape coefficients expected
    :returns: body_pose (23, 3), global orientation (3,),
        translation (3,), betas (num_betas,)
    """
    values = {
        "body_pose": np.zeros((23, 3)),
        "global_orient": np.zeros(3),
        "cam_trans": np.zeros(3),
        "betas": np.zeros(num_betas),
    }
    if os.path.isfile(pred_path):
        with open(pred_path, "r") as f:
            data = json.load(f)
        for key in values:
            if key in data:
                values[key] = np.array(data[key], dtype=np.float32)

    return (
        values["body_pose"],
        values["global_orient"],
        values["cam_trans"],
        values["betas"],
    )
77
+
78
+
79
def load_smpl_preds(pred_paths, interp=True, num_betas=10):
    """
    Load per-frame SMPL predictions for a track, optionally interpolating
    through missing (occluded) frames.

    Rotations are interpolated with spherical linear interpolation (Slerp);
    translations and betas with linear interpolation. Frames outside the
    first/last visible frame keep their zero defaults.

    :param pred_paths: per-frame *_smpl.json paths (missing files allowed)
    :param interp: whether to fill gaps between visible frames
    :param num_betas: number of shape coefficients
    :returns: pose (N, 23, 3), orient (N, 3), trans (N, 3), betas (N, num_betas)
    """
    vis_mask = np.array([os.path.isfile(x) for x in pred_paths])
    vis_idcs = np.where(vis_mask)[0]

    # load single image smpl predictions
    stack_fnc = functools.partial(np.stack, axis=0)
    # (N, 23, 3), (N, 3), (N, 3), (N, 10)
    pose, orient, trans, betas = map(
        stack_fnc, zip(*[read_smpl_preds(p, num_betas=num_betas) for p in pred_paths])
    )
    if not interp:
        return pose, orient, trans, betas

    # interpolate the occluded tracks
    # NOTE(review): assumes at least two visible frames — Slerp/interp1d
    # raise otherwise; confirm callers filter such tracks upstream
    orient_slerp = Slerp(vis_idcs, Rotation.from_rotvec(orient[vis_idcs]))
    trans_interp = interp1d(vis_idcs, trans[vis_idcs], axis=0)
    betas_interp = interp1d(vis_idcs, betas[vis_idcs], axis=0)

    # only fill within the visible span; outside stays zero
    tmin, tmax = min(vis_idcs), max(vis_idcs) + 1
    times = np.arange(tmin, tmax)
    orient[times] = orient_slerp(times).as_rotvec()
    trans[times] = trans_interp(times)
    betas[times] = betas_interp(times)

    # interpolate for each joint angle
    for i in range(pose.shape[1]):
        pose_slerp = Slerp(vis_idcs, Rotation.from_rotvec(pose[vis_idcs, i]))
        pose[times, i] = pose_slerp(times).as_rotvec()

    return pose, orient, trans, betas
slahmr/slahmr/data/vidproc.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import subprocess
4
+
5
+ import preproc.launch_phalp as phalp
6
+ from preproc.launch_slam import split_frames_shots, get_command, check_intrins
7
+ from preproc.extract_frames import video_to_frames
8
+
9
+
10
def is_nonempty(d):
    """Return True iff d is an existing directory containing at least one entry."""
    if not os.path.isdir(d):
        return False
    return len(os.listdir(d)) > 0
12
+
13
+
14
def preprocess_frames(img_dir, src_path, overwrite=False, **kwargs):
    """
    Extract frames from the src_path video into img_dir.
    Skipped if img_dir already contains frames and overwrite is False.
    Extra kwargs are forwarded to video_to_frames.
    """
    if not overwrite and is_nonempty(img_dir):
        print(f"FOUND {len(os.listdir(img_dir))} FRAMES in {img_dir}")
        return
    print(f"EXTRACTING FRAMES FROM {src_path} TO {img_dir}")
    print(kwargs)
    # video_to_frames is expected to return 0 on success (asserted below)
    out = video_to_frames(src_path, img_dir, overwrite=overwrite, **kwargs)
    assert out == 0, "FAILED FRAME EXTRACTION"
22
+
23
+
24
def preprocess_tracks(img_dir, track_dir, shot_dir, overwrite=False):
    """
    Run PHALP tracking on img_dir.
    Skipped if track_dir already contains output and overwrite is False.
    :param img_dir
    :param track_dir, expected format: res_root/track_name/sequence
    :param shot_dir, expected format: res_root/shot_name/sequence
    """
    if not overwrite and is_nonempty(track_dir):
        print(f"FOUND TRACKS IN {track_dir}")
        return

    print(f"RUNNING PHALP ON {img_dir}")
    # recover res_root / track_name / sequence from the track_dir layout
    track_root, seq = os.path.split(track_dir.rstrip("/"))
    res_root, track_name = os.path.split(track_root)
    shot_name = shot_dir.rstrip("/").split("/")[-2]
    # default to device 0 when CUDA_VISIBLE_DEVICES is unset
    gpu = os.environ.get("CUDA_VISIBLE_DEVICES", 0)

    phalp.process_seq(
        [gpu],
        seq,
        img_dir,
        f"{res_root}/phalp_out",
        track_name=track_name,
        shot_name=shot_name,
        overwrite=overwrite,
    )
49
+
50
+
51
def preprocess_cameras(cfg, overwrite=False):
    """
    Run SLAM on the configured shot/segment to estimate cameras.
    Skipped if cfg.sources.cameras already contains output and overwrite
    is False. Asserts if the SLAM subprocess fails.
    """
    if not overwrite and is_nonempty(cfg.sources.cameras):
        print(f"FOUND CAMERAS IN {cfg.sources.cameras}")
        return

    print(f"RUNNING SLAM ON {cfg.seq}")
    img_dir = cfg.sources.images
    map_dir = cfg.sources.cameras
    # split the sequence into shots and locate the requested one
    subseqs, shot_idcs = split_frames_shots(cfg.sources.images, cfg.sources.shots)
    shot_idx = np.where(shot_idcs == cfg.shot_idx)[0][0]
    # run on selected shot
    start, end = subseqs[shot_idx]
    if not cfg.split_cameras:
        # only run on specified segment within shot
        end = start + cfg.end_idx
        start = start + cfg.start_idx
    intrins_path = cfg.sources.get("intrins", None)
    if intrins_path is not None:
        intrins_path = check_intrins(cfg.type, cfg.root, intrins_path, cfg.seq, cfg.split)

    cmd = get_command(
        img_dir,
        map_dir,
        start=start,
        end=end,
        intrins_path=intrins_path,
        overwrite=overwrite,
    )
    print(cmd)
    # run the SLAM command on the same GPU as this process
    gpu = os.environ.get("CUDA_VISIBLE_DEVICES", 0)
    out = subprocess.call(f"CUDA_VISIBLE_DEVICES={gpu} {cmd}", shell=True)
    assert out == 0, "SLAM FAILED"
slahmr/slahmr/eval/__init__.py ADDED
File without changes
slahmr/slahmr/eval/associate.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import json
4
+ import joblib
5
+ import numpy as np
6
+ import torch
7
+
8
+ from data.tools import read_keypoints
9
+
10
+
11
def associate_phalp_track_dirs(
    phalp_dir, img_dir, track_ids, gt_kps, start=0, end=-1, debug=False
):
    """
    Associate the M track_ids with G GT tracks
    returns (M, T) array of best matching GT person index (-1 where unmatched)
    :param phalp_dir (str) directory with phalp track folders
    :param img_dir (str) directory with source images
    :param track_ids (list) M tracks to match
    :param gt_kps (G, T, 3, J) gt keypoints for G people, T times, J joints
        (each frame is transposed to (J, 3) below)
    :param start (optional int default 0)
    :param end (optional int default -1)
    """
    img_names = sorted([os.path.splitext(x)[0] for x in os.listdir(img_dir)])
    N = len(img_names)
    # negative end counts from the back (end=-1 keeps through the last frame)
    end = N + 1 + end if end < 0 else end
    sel_imgs = img_names[start:end]
    G, T = gt_kps.shape[:2]  # G num people, T num frames
    assert len(sel_imgs) == T, f"found {len(sel_imgs)} frames, expected {T}"

    # phalp track folders are zero-padded 3-digit ids
    track_ids = [f"{int(tid):03d}" for tid in track_ids]
    M = len(track_ids)
    # find the best matching GT track for each PHALP track
    match_idcs = torch.full((M, T), -1)
    for t, frame_name in enumerate(sel_imgs):
        track_kps = []  # get track keypoints
        for tid in track_ids:
            kp_path = f"{phalp_dir}/{tid}/{frame_name}_keypoints.json"
            track_kps.append(read_keypoints(kp_path))
        track_kps = np.stack(track_kps, axis=0)  # (M, 25, 3)
        for g in range(G):
            kp_gt = gt_kps[g, t].T.numpy()  # (18, 3)
            m = associate_keypoints(kp_gt, track_kps, debug=debug)
            if m == -1:
                continue
            # NOTE(review): a later GT person can overwrite an earlier
            # person's match for the same track in this frame
            match_idcs[m, t] = g
    return match_idcs
48
+
49
+
50
def associate_phalp_track_data(
    phalp_file, track_ids, gt_kps, start=0, end=-1, debug=False
):
    """
    Get the best GT person for each phalp track
    :param phalp_file (path) to phalp result pickle file
    :param gt_kps (G, T, 3, 18) gt keypoints
    :param track_ids (list) of phalp track ids
    :param start (optional int)
    :param end (optional int)
    return (M, T) array the matching GT person index for each phalp track
        (-1 where unmatched)
    """
    data = joblib.load(phalp_file)
    img_names = sorted(data.keys())
    N = len(img_names)  # number of frames
    # negative end counts from the back (end=-1 keeps through the last frame)
    end = N + 1 + end if end < 0 else end
    sel_imgs = img_names[start:end]

    G, T = gt_kps.shape[:2]  # G num people, T num frames
    assert len(sel_imgs) == T, f"found {len(sel_imgs)} frames, expected {T}"

    M = len(track_ids)
    # map each track id to its row in the output array
    track_idcs = {tid: m for m, tid in enumerate(track_ids)}
    # get the best matching GT track for each PHALP track
    match_idcs = torch.full((M, T), -1)
    for t, frame_name in enumerate(sel_imgs):
        frame_data = data[frame_name]
        for g in range(G):
            kp_gt = gt_kps[g, t].T.numpy()  # (18, 3)
            # get the best track ID for the GT person
            tid = associate_frame_dict(frame_data, kp_gt, track_ids, debug=debug)
            if tid == -1:
                continue
            m = track_idcs[tid]
            match_idcs[m, t] = g
    return match_idcs
86
+
87
+
88
def associate_keypoints(gt_kps, track_kps, debug=False):
    """
    Find the track whose keypoint bounding box best overlaps the GT keypoints.
    :param gt_kps (25, 3) keypoints with confidence in the last column
    :param track_kps (M, 25, 3)
    return the index of the best overlapping track bbox,
        or -1 if the GT frame has no confident keypoints
    """
    # keep only confident GT keypoints
    gt_kps = gt_kps[gt_kps[:, 2] > 0, :2]
    if len(gt_kps) < 1:
        return -1
    # tight bbox around the confident GT keypoints
    bb_min, bb_max = gt_kps.min(axis=0), gt_kps.max(axis=0)
    gt_bbox = np.concatenate([bb_min, bb_max], axis=-1)  # (4,)

    track_kps = track_kps[..., :2]  # (M, 25, 2)
    track_min, track_max = track_kps.min(axis=1), track_kps.max(axis=1)
    track_bboxes = np.concatenate([track_min, track_max], axis=-1)  # (M, 4)

    # NOTE(review): returns an index even when the best IOU is 0
    ious = np.stack([compute_iou(bb, gt_bbox)[0] for bb in track_bboxes], axis=0)
    return np.argmax(ious)
106
+
107
+
108
def associate_frame_dict(frame_data, gt_kps, track_ids, debug=False):
    """
    For the GT keypoints, find the PHALP track in track_ids with best overlap
    :param frame_data (dict) PHALP output data
    :param gt_kps (25, 3)
    :param track_ids (list of N) PHALP track ids to search over
    return the id in track_ids with the biggest overlap with gt_kps,
        or -1 if the GT frame has no confident keypoints
    """
    gt_kps = gt_kps[gt_kps[:, 2] > 0, :2]
    if len(gt_kps) < 1:
        return -1
    bb_min, bb_max = gt_kps.min(axis=0), gt_kps.max(axis=0)
    gt_bbox = np.concatenate([bb_min, bb_max], axis=-1)  # (4,)

    # use strs for track ids
    # NOTE(review): tid_strs is built but never used below
    tid_strs = [str(tid) for tid in track_ids]
    # get the list indices of the PHALP tracks
    # NOTE(review): keys are str(int(tid)); if track_ids contains zero-padded
    # strings (e.g. "003") the membership test below never matches -- verify
    track_idcs = {
        str(int(tid)): i
        for i, tid in enumerate(frame_data["tid"])
        if tid in frame_data["tracked_ids"]
    }
    # select the track with the biggest overlap with the gt kps
    ious = []
    for tid in track_ids:
        if tid not in track_idcs:
            ious.append(0)
            continue
        bb = frame_data["bbox"][track_idcs[tid]]  # (min_x, min_y, w, h)
        # convert to (min_x, min_y, max_x, max_y)
        bbox = np.concatenate([bb[:2], bb[:2] + bb[2:]], axis=-1)
        iou = compute_iou(bbox, gt_bbox)[0]
        ious.append(iou)
    ious = np.stack(ious, axis=0)
    idx = np.argmax(ious)
    if debug:
        print(track_ids[idx], track_ids, ious)
    return track_ids[idx]
145
+
146
+
147
def compute_iou(bb1, bb2):
    """
    Intersection-over-union of axis-aligned boxes.
    :param bb1 (..., 4) top left x, y bottom right x y
    :param bb2 (..., 4) top left x, y bottom right x y
    return (...) IOU in [0, 1]
    """
    x11, y11, x12, y12 = np.split(bb1, 4, axis=-1)
    x21, y21, x22, y22 = np.split(bb2, 4, axis=-1)
    x1 = np.maximum(x11, x21)
    y1 = np.maximum(y11, y21)
    x2 = np.minimum(x12, x22)
    y2 = np.minimum(y12, y22)
    # fix: clamp each overlap dimension separately; clamping only the product
    # reported spurious overlap for boxes disjoint in BOTH x and y
    # (negative * negative > 0)
    intersect = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
    union = (x12 - x11) * (y12 - y11) + (x22 - x21) * (y22 - y21) - intersect
    return intersect / (union + 1e-6)
slahmr/slahmr/eval/egobody_utils.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import itertools
3
+ import glob
4
+ import pickle
5
+ import json
6
+ import pandas as pd
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from tools import load_body_model, move_to, detach_all, EGOBODY_ROOT
12
+
13
+
14
def get_sequence_body_info(seq_name):
    """Look up the 'body_idx_fpv' entry for a recording in the EgoBody info csv."""
    info_df = pd.read_csv(f"{EGOBODY_ROOT}/data_info_release.csv")
    matched = info_df[info_df["recording_name"] == seq_name]
    return matched["body_idx_fpv"].values[0]
19
+
20
+
21
def get_egobody_split(split):
    """Return the recording names of a data split, or [] if the split is unknown."""
    split_file = f"{EGOBODY_ROOT}/data_splits.csv"
    df = pd.read_csv(split_file)
    if split in df.columns:
        return df[split].dropna().tolist()
    print(f"{split} not in {split_file}")
    return []
28
+
29
+
30
def get_egobody_seq_paths(seq_name, start=0, end=-1):
    """
    Return the sorted image file NAMES (not full paths) of a sequence,
    sliced to [start:end].
    NOTE(review): end<0 maps to len(img_files) here, whereas
    get_egobody_keypoints uses len+1+end -- confirm intended semantics
    for negative end values other than -1.
    """
    img_dir = get_egobody_img_dir(seq_name)
    # img files are named [timestamp]_frame_[index].jpg
    img_files = sorted(os.listdir(img_dir))
    end = len(img_files) if end < 0 else end
    print(f"FOUND {len(img_files)} FILES FOR SEQ {seq_name}")
    return img_files[start:end]
37
+
38
+
39
def get_egobody_seq_names(seq_name, start=0, end=-1):
    """
    Return the frame names of a sequence: file stems with the leading
    timestamp component (before the first '_') removed.
    """
    names = []
    for fname in get_egobody_seq_paths(seq_name, start=start, end=end):
        stem = fname.split(".")[0]
        names.append("_".join(stem.split("_")[1:]))
    return names
43
+
44
+
45
def get_egobody_img_dir(seq_name):
    """Resolve the single PV image directory of an EgoBody recording via glob."""
    img_dir = f"{EGOBODY_ROOT}/egocentric_color/{seq_name}/**/PV"
    matches = glob.glob(img_dir)
    if len(matches) == 1:
        return matches[0]
    raise ValueError(f"{img_dir} has {len(matches)} matches!")
51
+
52
+
53
def get_egobody_keypoints(seq_name, start=0, end=-1):
    """
    Load the 2D keypoints and per-frame validity flags of a sequence.
    returns (kps, valid): kps stacked over the selected frames, valid a
    bool array; frames without stored annotations get zeros / False.
    """
    img_dir = os.path.dirname(get_egobody_img_dir(seq_name))
    kp_file = f"{img_dir}/keypoints.npz"
    valid_file = f"{img_dir}/valid_frame.npz"

    # missing keypoints aren't included, must fill in
    kp_dict = {}
    valid_dict = {}
    kp_data = np.load(kp_file)
    valid_data = np.load(valid_file)

    # zero template for frames without annotations
    zeros = np.zeros_like(kp_data["keypoints"][0])
    for img_path, kps in zip(kp_data["imgname"], kp_data["keypoints"]):
        img_name = os.path.basename(img_path)
        kp_dict[img_name] = kps

    for img_path, valid in zip(valid_data["imgname"], valid_data["valid"]):
        img_name = os.path.basename(img_path)
        valid_dict[img_name] = valid

    img_paths = sorted(glob.glob(f"{img_dir}/PV/*.jpg"))
    # negative end counts from the back (end=-1 keeps all frames)
    end = len(img_paths) + 1 + end if end < 0 else end
    img_names = [os.path.basename(x) for x in img_paths[start:end]]
    kps = np.stack([kp_dict.get(name, zeros) for name in img_names], axis=0)
    valid = np.stack([valid_dict.get(name, False) for name in img_names], axis=0)
    return kps, valid
79
+
80
+
81
def load_egobody_smpl_params(seq_name, start=0, end=-1):
    """
    Load the GT interactee SMPL parameters for each selected frame.
    returns dict with trans, pose_body, root_orient, betas each stacked to
    a (1, T, -1) tensor, plus genders (list with the single gender string).
    Raises ValueError if the expected SMPL result directory is missing.
    """
    frame_names = get_egobody_seq_names(seq_name, start=start, end=end)
    # body_idx_fpv is "<idx> <gender>"
    body_name = get_sequence_body_info(seq_name)
    body_idx, gender = body_name.split(" ")
    smpl_dir = (
        f"{EGOBODY_ROOT}/smpl_interactee_val/{seq_name}/body_idx_{body_idx}/results"
    )
    if not os.path.isdir(smpl_dir):
        raise ValueError(f"EXPECTED BODY DIR {smpl_dir} DOES NOT EXIST")

    print(f"LOADING {len(frame_names)} SMPL PARAMS FROM {smpl_dir}")
    smpl_dict = {"trans": [], "root_orient": [], "pose_body": [], "betas": []}
    for frame in frame_names:
        with open(f"{smpl_dir}/{frame}/000.pkl", "rb") as f:
            # data has global_orient, body_pose, betas, transl
            data = pickle.load(f)
        smpl_dict["trans"].append(torch.from_numpy(data["transl"]))
        smpl_dict["pose_body"].append(torch.from_numpy(data["body_pose"]))
        smpl_dict["root_orient"].append(torch.from_numpy(data["global_orient"]))
        smpl_dict["betas"].append(torch.from_numpy(data["betas"]))
    # concatenate over frames and add a leading person dimension
    smpl_dict = {k: torch.cat(v, dim=0)[None] for k, v in smpl_dict.items()}
    smpl_dict["genders"] = [gender]
    return smpl_dict
104
+
105
+
106
def load_egobody_intrinsics(seq_name, start=0, end=-1, ret_size_tuple=True):
    """
    Load per-frame intrinsics from a (T, 6) txt file: 4 intrinsic values
    followed by 2 image-size values per row.
    If ret_size_tuple, returns (intrins (T, 4) float tensor, image-size int
    list from the FIRST frame only, in on-disk column order); otherwise the
    per-frame sizes are returned as a (T, 2) int tensor.
    """
    path = f"{EGOBODY_ROOT}/slahmr/cameras_gt/{seq_name}/intrinsics.txt"
    assert os.path.isfile(path)
    intrins = np.loadtxt(path)  # (T, 6)
    end = len(intrins) if end < 0 else end
    intrins = intrins[start:end]
    if ret_size_tuple:
        img_size = intrins[0, 4:].astype(int).tolist()  # (2)
        intrins = torch.from_numpy(intrins[:, :4].astype(np.float32))
        return intrins, img_size
    img_size = torch.from_numpy(intrins[:, 4:].astype(int))
    intrins = torch.from_numpy(intrins[:, :4].astype(np.float32))
    return intrins, img_size
119
+
120
+
121
def load_egobody_gt_extrinsics(seq_name, start=0, end=-1, ret_4d=True):
    """
    Load the GT cam2world transforms of a sequence as a (T, 4, 4) tensor,
    or as (R (T, 3, 3), t (T, 3)) when ret_4d is False.
    """
    path = f"{EGOBODY_ROOT}/slahmr/cameras_gt/{seq_name}/cam2world.txt"
    assert os.path.isfile(path)
    mats = np.loadtxt(path).astype(np.float32)  # (T, 16)
    stop = len(mats) if end < 0 else end
    cam2world = torch.from_numpy(mats[start:stop].reshape(-1, 4, 4))
    if ret_4d:
        return cam2world
    return cam2world[:, :3, :3], cam2world[:, :3, 3]
130
+
131
+
132
def load_egobody_extrinsics(seq_name, use_intrins=True, start=0, end=-1):
    """
    Load estimated world2cam matrices, invert to cam2world, and return
    (R (T, 3, 3), t (T, 3)).
    :param use_intrins select the intrinsics-aware camera estimates
    """
    camera_name = "cameras_intrins" if use_intrins else "cameras_default"
    path = f"{EGOBODY_ROOT}/slahmr/{camera_name}/{seq_name}/cameras.npz"
    assert os.path.isfile(path)
    data = np.load(path)
    w2c = torch.from_numpy(data["w2c"].astype(np.float32))  # (N, 4, 4)
    end = len(w2c) if end < 0 else end
    w2c = w2c[start:end]
    # invert to camera-to-world
    c2w = torch.linalg.inv(w2c)
    return c2w[:, :3, :3], c2w[:, :3, 3]
142
+
143
+
144
def load_egobody_meshes(seq_name, device, start=0, end=-1):
    """
    Run the SMPL body model on the GT interactee parameters and return a
    dict with joints, vertices, and faces, detached and moved to cpu.
    """
    params = load_egobody_smpl_params(seq_name, start=start, end=end)
    _, T = params["trans"].shape[:2]

    with torch.no_grad():
        gender = params["genders"][0]
        body_model = load_body_model(T, "smpl", gender, device)
        smpl_res = body_model(
            trans=params["trans"][0].to(device),
            root_orient=params["root_orient"][0].to(device),
            betas=params["betas"][0].to(device),
            pose_body=params["pose_body"][0].to(device),
        )

    res = {"joints": smpl_res.Jtr, "vertices": smpl_res.v, "faces": smpl_res.f}
    return move_to(detach_all(res), "cpu")
160
+
161
+
162
def load_egobody_kinect2holo(seq_name, ret_4d=True):
    """
    Load the kinect12 -> holo transform (inverse of the stored holo ->
    kinect12 calibration). Returns a (4, 4) tensor, or (R (3, 3), t (3,))
    when ret_4d is False.
    """
    # load the transform from kinect12 to holo
    # bodies are recorded in the kinect12 frame
    path = f"{EGOBODY_ROOT}/calibrations/{seq_name}/cal_trans/holo_to_kinect12.json"
    with open(path, "r") as f:
        kinect2holo = np.linalg.inv(np.array(json.load(f)["trans"]))
    kinect2holo = torch.from_numpy(kinect2holo.astype(np.float32))
    if ret_4d:
        return kinect2holo
    return kinect2holo[:3, :3], kinect2holo[:3, 3]
slahmr/slahmr/eval/run_eval.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import joblib
4
+ import json
5
+ import pickle
6
+
7
+ import pandas as pd
8
+ import numpy as np
9
+ import torch
10
+
11
+ import egobody_utils as eb_util
12
+ from tools import (
13
+ load_body_model,
14
+ load_results_all,
15
+ local_align_joints,
16
+ global_align_joints,
17
+ first_align_joints,
18
+ compute_accel_norm,
19
+ run_smpl,
20
+ JointRegressor,
21
+ EGOBODY_ROOT,
22
+ TDPW_ROOT,
23
+ )
24
+ from associate import associate_phalp_track_dirs
25
+
26
+
27
def stack_torch(x_list, dim=0):
    """Convert numpy arrays to float32 torch tensors and stack them along dim."""
    converted = []
    for arr in x_list:
        converted.append(torch.from_numpy(arr.astype(np.float32)))
    return torch.stack(converted, dim=dim)
31
+
32
+
33
def load_3dpw_params(seq_name, start=0, end=-1):
    """
    Load GT SMPL parameters and 2D keypoints for a 3DPW test sequence.
    returns dict with root_orient (M, T, 3), pose_body (M, T, 69),
    trans (M, T, 3), betas (M, T, 10), keypts2d (M, T, 3, 18),
    valid (M, T) bool, genders (list of M strings)
    """
    seq_file = f"{TDPW_ROOT}/sequenceFiles/test/{seq_name}.pkl"
    with open(seq_file, "rb") as f:
        data = pickle.load(f, encoding="latin1")

    M = len(data["poses"])
    T = len(data["poses"][0])
    # negative end counts from the back (end=-1 keeps through the last frame)
    end = T + 1 + end if end < 0 else end
    T = end - start  # T is now the selected window length
    trans = stack_torch([x[start:end] for x in data["trans"]])  # (M, T, 3)
    poses = stack_torch([x[start:end] for x in data["poses"]])  # (M, T, 72)
    # betas are per-person constants; broadcast over time
    betas = stack_torch([x[None, :10] for x in data["betas"]]).expand(
        M, T, 10
    )  # (M, T, 10)
    keypts2d = stack_torch([x[start:end] for x in data["poses2d"]])  # (M, T, 3, 18)
    valid_cam = stack_torch(
        [x[start:end] for x in data["campose_valid"]]
    ).bool()  # (M, T)
    # a frame counts as valid only with a valid campose AND any keypoint
    valid_kp = (keypts2d.reshape(M, T, -1) > 0).any(dim=-1).bool()  # (M, T)
    valid = valid_cam & valid_kp
    genders = ["male" if x == "m" else "female" for x in data["genders"]]  # (M)
    return {
        "root_orient": poses[..., :3],
        "pose_body": poses[..., 3:],
        "trans": trans,
        "betas": betas,
        "keypts2d": keypts2d,
        "valid": valid,
        "genders": genders,
    }
63
+
64
+
65
def load_egobody_params(seq_name, start=0, end=-1):
    """
    Load GT EgoBody SMPL parameters plus 2D keypoints and validity flags.
    returns dict of
    - trans (1, T, 3)
    - root_orient (1, T, 3)
    - pose_body (1, T, 63)
    - betas (1, T, 10)
    - genders (list with one gender str)
    - keypts2d (1, T, J, 3)
    - valid (1, T)
    """
    smpl_dict = eb_util.load_egobody_smpl_params(seq_name, start=start, end=end)
    kps, valid = eb_util.get_egobody_keypoints(seq_name, start=start, end=end)
    # add a leading person dimension to match the smpl params
    smpl_dict["keypts2d"] = torch.from_numpy(kps.astype(np.float32))[None]
    smpl_dict["valid"] = torch.from_numpy(valid.astype(bool))[None]
    return smpl_dict
81
+
82
+
83
def eval_result_dir(
    dset_type, res_dir, out_path, joint_reg, dev_id=0, overwrite=False, debug=False
):
    """
    Evaluate one optimized result directory against GT and write a csv of
    per-phase metrics to out_path.
    :param dset_type (str) "egobody" or "3dpw"
    :param res_dir (str) result directory with track_info.json and one
        subdirectory per optimization phase
    :param out_path (str) csv output path
    :param joint_reg (JointRegressor) maps SMPL vertices to eval joints
    :param dev_id (int) cuda device index
    """
    if os.path.isfile(out_path) and not overwrite:
        print(f"{out_path} already exists, skipping.")
        return

    # get the output metadata
    track_file = f"{res_dir}/track_info.json"
    if not os.path.isfile(track_file):
        print(f"{track_file} does not exist, skipping")
        return

    with open(track_file, "r") as f:
        track_dict = json.load(f)
    start, end = track_dict["meta"]["data_interval"]
    seq_name = os.path.basename(res_dir).split("-")[0]
    print("EVALUATING", res_dir, seq_name, start, end)

    # get the associations from PHALP tracks to GT tracks
    track_info = track_dict["tracks"]
    track_ids = sorted(track_info, key=lambda k: track_info[k]["index"])
    print("TRACK IDS", track_ids)

    if dset_type == "egobody":
        # load the GT params
        gt_params = load_egobody_params(seq_name, start, end)
        phalp_dir = f"{EGOBODY_ROOT}/slahmr/track_preds/{seq_name}"
        img_dir = eb_util.get_egobody_img_dir(seq_name)
    elif dset_type == "3dpw":
        gt_params = load_3dpw_params(seq_name, start, end)
        phalp_dir = f"{TDPW_ROOT}/slahmr/track_gt/{seq_name}"
        img_dir = f"{TDPW_ROOT}/imageFiles/{seq_name}"
    else:
        raise NotImplementedError

    # (M, T) GT track index for each frame and each PHALP track
    match_idcs = associate_phalp_track_dirs(
        phalp_dir,
        img_dir,
        track_ids,
        gt_params["keypts2d"],
        start=start,
        end=end,
        debug=debug,
    )
    # M number of PHALP tracks
    M = len(track_ids)

    # get the GT joints by running the body model per GT person
    G, T = gt_params["pose_body"].shape[:2]
    device = torch.device(f"cuda:{dev_id}")
    gt_joints = []
    for g in range(G):
        body_model = load_body_model(T, "smpl", gt_params["genders"][g], device)
        gt_smpl = run_smpl(
            body_model,
            betas=gt_params["betas"][g].to(device),
            trans=gt_params["trans"][g].to(device),
            root_orient=gt_params["root_orient"][g].to(device),
            pose_body=gt_params["pose_body"][g].to(device),
        )
        gt_joints.append(joint_reg(gt_smpl["vertices"]))  # (T, J, 3)
    gt_joints = torch.stack(gt_joints, dim=0)
    J, D = gt_joints.shape[-2:]

    # select the correct GT person for each track
    gt_valid = gt_params["valid"]  # (G, T)
    idcs = match_idcs.clone().reshape(M, T, 1, 1).expand(-1, -1, J, D)
    idcs[idcs == -1] = 0  # gather dummy for invalid matches
    gt_match_joints = torch.gather(gt_joints, 0, idcs)
    gt_match_valid = torch.gather(gt_valid, 0, idcs[:, :, 0, 0])
    valid = gt_match_valid & (match_idcs != -1)

    # use the vis_mask to get the correct data subsequence
    vis_mask = torch.tensor(
        [track_info[tid]["vis_mask"] for tid in track_ids]
    )  # (M, T)
    vis_tracks = torch.where(vis_mask.any(dim=1))[0]  # (B,)
    vis_idcs = torch.where(vis_mask.any(dim=0))[0]
    sidx, eidx = vis_idcs.min(), vis_idcs.max() + 1
    L = eidx - sidx

    valid_seq = valid[vis_tracks, sidx:eidx]  # (B, L)
    gt_seq_joints = gt_match_joints[vis_tracks, sidx:eidx]  # (B, L, *)
    gt_seq_joints = gt_seq_joints[valid_seq]

    if debug:
        print(f"vis start {sidx}, end {eidx}, L {L}")
        print("valid track matches", (match_idcs != -1).sum())
        print("filtered gt joints", gt_seq_joints.shape)

    # get the outputs of each phase
    PHASES = ["root_fit", "smooth_fit", "motion_chunks"]
    metric_names = ["ga_jmse", "fa_jmse", "pampjpe", "acc_norm"]
    # fix: was `for _ in PHASE` (NameError) -- iterate the defined PHASES list
    phase_metrics = {name: [-1 for _ in PHASES] for name in metric_names}
    cur_metrics = {name: np.nan for name in metric_names}
    for i, phase in enumerate(PHASES):
        res_dict = load_results_all(os.path.join(res_dir, phase), device)
        if res_dict is None:
            print(f"PHASE {phase} did not optimize")
            # carry the previous phase's metrics forward for this phase
            for name in metric_names:
                phase_metrics[name][i] = float(cur_metrics[name])
            print(phase, phase_metrics)
            continue

        # (M, L, -1, 3) verts, (M, L) mask
        res_verts = res_dict["vertices"][valid_seq]
        res_joints = joint_reg(res_verts)  # (*, J, 3)

        for name in metric_names:
            # fix: these were independent `if`s with an `else` dangling off
            # the last one, so "pampjpe" and "ga_jmse" always raised
            # NotImplementedError -- must be an elif chain
            if name == "acc_norm":
                target = compute_accel_norm(gt_seq_joints)  # (T-2, J)
                pred = compute_accel_norm(res_joints)
            elif name == "pampjpe":
                target = gt_seq_joints
                pred = local_align_joints(gt_seq_joints, res_joints)
            elif name == "ga_jmse":
                target = gt_seq_joints
                pred = global_align_joints(gt_seq_joints, res_joints)
            elif name == "fa_jmse":
                target = gt_seq_joints
                pred = first_align_joints(gt_seq_joints, res_joints)
            else:
                raise NotImplementedError
            cur_metrics[name] = torch.linalg.norm(target - pred, dim=-1).mean()
            phase_metrics[name][i] = float(cur_metrics[name])
            print(phase, name, cur_metrics[name])

    df_dict = {"phases": PHASES}
    df_dict.update(phase_metrics)
    df = pd.DataFrame.from_dict(df_dict)
    df.to_csv(out_path, index=False)
    print(f"saved metrics to {out_path}")
217
+
218
+
219
def parse_job_file(args):
    """
    Parse job-spec lines of the form
    "<seq> data.start_idx=<s> data.end_idx=<e> [data.track_ids=<name>]"
    into result-directory names "<seq>-<track_name>-<start>-<end>".
    """
    subseq_names = []
    with open(args.job_file, "r") as f:
        for line in f.readlines():
            cmd_args = line.strip().split()
            seq_name, start_str, end_str = cmd_args[:3]
            start = start_str.split("=")[-1]
            end = end_str.split("=")[-1]
            # default track selection differs per dataset
            track_name = "longest-2" if args.dset_type == "3dpw" else "all"
            if len(cmd_args) > 3:
                track_name = cmd_args[3].split("=")[-1]
            subseq_names.append(f"{seq_name}-{track_name}-{start}-{end}")
    return subseq_names
232
+
233
+
234
def main(args):
    """Evaluate every subsequence listed in the job file, then aggregate metrics."""
    joint_reg = JointRegressor()
    out_root = args.out_root if args.out_root is not None else args.res_root
    os.makedirs(out_root, exist_ok=True)

    subseq_names = parse_job_file(args)
    for subseq in subseq_names:
        res_dir = os.path.join(args.res_root, subseq)
        out_path = os.path.join(out_root, f"{subseq}.txt")
        eval_result_dir(
            args.dset_type,
            res_dir,
            out_path,
            joint_reg,
            overwrite=args.overwrite,
            debug=args.debug,
        )

    # aggregate all per-subsequence metric files (the [!_] glob skips the
    # _final_metrics.txt output itself)
    metric_paths = glob.glob(f"{out_root}/[!_]*.txt")
    dfs = [pd.read_csv(path) for path in metric_paths]
    if not dfs:
        # pd.concat raises on an empty list; nothing was evaluated
        print("no metric files found, nothing to aggregate")
        return

    # fix: the per-subsequence csvs are written with a "phases" column
    # (see eval_result_dir), so grouping must use "phases", not "phase"
    merged = pd.concat(dfs).groupby("phases").mean()
    merged.to_csv(f"{out_root}/_final_metrics.txt")
    print(merged)
259
+
260
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--dset_type",
        required=True,
        choices=["egobody", "3dpw"],
        help="dataset to evaluate on, choices: (3dpw, egobody)",
    )
    parser.add_argument(
        "-i", "--res_root", required=True, help="root directory of outputs to evaluate"
    )
    parser.add_argument(
        "-f",
        "--job_file",
        required=True,
        help="job file specifying the examples to run and evaluate",
    )
    parser.add_argument(
        "-o",
        "--out_root",
        default=None,
        help="directory to save computed metrics, default is res_root",
    )
    parser.add_argument("-y", "--overwrite", action="store_true")
    # fix: "-d" was already taken by --dset_type; registering a duplicate
    # short option makes argparse raise ArgumentError at startup, so
    # --debug is long-option only
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)
slahmr/slahmr/eval/split_3dpw.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import itertools
4
+ import joblib
5
+
6
+ from eval_3dpw import load_3dpw_params
7
+ from associate import associate_frame
8
+ from tools import TDPW_ROOT
9
+
10
+
11
+ """
12
+ Script to find the associations of ground truth 3DPW tracks with the detected PHALP tracks
13
+ Will write a job specification file to ../job_specs with which track IDs to run optimization on
14
+ """
15
+
16
+ IMG_ROOT = f"{TDPW_ROOT}/imageFiles"
17
+ SRC_DIR = f"{TDPW_ROOT}/sequenceFiles"
18
+ PHALP_DIR = f"{TDPW_ROOT}/slahmr/phalp_out/results"
19
+
20
+
21
def load_split_sequences(split):
    """Return the sequence names (file stems) stored under SRC_DIR/<split>."""
    assert split in ["train", "val", "test"]
    seq_files = sorted(os.listdir(f"{SRC_DIR}/{split}"))
    return [os.path.splitext(name)[0] for name in seq_files]
26
+
27
+
28
def select_phalp_tracks(seq_name, split, start, end, debug=False):
    """
    Select the best phalp track for each GT person for each frame.
    Returns all phalp tracks that match GT over sequence.
    NOTE(review): this module imports `associate_frame` from `associate` and
    `load_3dpw_params` from `eval_3dpw`; the versions visible elsewhere in
    this repo are `associate_frame_dict` (associate.py) and a
    `load_3dpw_params` that takes a sequence NAME rather than a pkl path --
    verify these imports resolve as intended.
    """
    phalp_file = f"{PHALP_DIR}/{seq_name}.pkl"
    track_data = joblib.load(phalp_file)
    img_names = sorted(track_data.keys())
    sel_imgs = img_names[start:end]

    gt_params = load_3dpw_params(f"{SRC_DIR}/{split}/{seq_name}.pkl", start, end)
    gt_kps = gt_params["keypts2d"]
    G, T = gt_kps.shape[:2]  # G num people, T num frames
    assert len(sel_imgs) == T, f"found {len(sel_imgs)} frames, expected {T}"

    # collect every track id seen in the selected frames
    track_ids = set()
    for frame in sel_imgs:
        frame_data = track_data[frame]
        for tid in frame_data["tracked_ids"]:
            track_ids.add(str(tid))
    track_ids = list(track_ids)
    # NOTE(review): M and track_idcs are computed but never used below
    M = len(track_ids)
    track_idcs = {tid: m for m, tid in enumerate(track_ids)}

    # get the best matching PHALP track for each GT person
    sel_tracks = set()
    for t, frame_name in enumerate(sel_imgs):
        frame_data = track_data[frame_name]
        for g in range(G):
            kp_gt = gt_kps[g, t].T.numpy()  # (18, 3)
            # get the best track ID for the GT person
            tid = associate_frame(frame_data, kp_gt, track_ids, debug=debug)
            if tid == -1:
                continue
            sel_tracks.add(int(tid))
    return list(sel_tracks)
64
+
65
+
66
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seq_len", type=int, default=100)
    parser.add_argument(
        "--split", default="test", choices=["train", "val", "test", "all"]
    )
    parser.add_argument("--prefix", default="3dpw")
    args = parser.parse_args()

    # NOTE(review): load_split_sequences asserts split in {train,val,test},
    # so --split all fails there despite being an accepted choice here
    seqs = load_split_sequences(args.split)

    job_arg_strs = []
    for seq in seqs:
        num_imgs = len(glob.glob(f"{IMG_ROOT}/{seq}/*.jpg"))
        # chunk the sequence into seq_len-frame windows
        splits = list(range(0, num_imgs, args.seq_len))
        splits[-1] = num_imgs  # just add the remainder to the last job
        for start, end in zip(splits[:-1], splits[1:]):
            sel_tracks = select_phalp_tracks(seq, args.split, start, end)
            if len(sel_tracks) < 1:
                continue
            track_str = "-".join([f"{tid:03d}" for tid in sel_tracks])
            arg_str = (
                f"{seq} data.start_idx={start} data.end_idx={end} "
                f"data.track_ids={track_str}"
            )
            print(arg_str)
            job_arg_strs.append(arg_str)

    with open(
        f"../job_specs/{args.prefix}_{args.split}_len_{args.seq_len}.txt", "w"
    ) as f:
        f.write("\n".join(job_arg_strs))
slahmr/slahmr/eval/split_egobody.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import joblib
4
+ import itertools
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from associate import associate_frame
10
+ from tools import EGOBODY_ROOT
11
+
12
+
13
+ """
14
+ Script to find the associations of ground truth Egobody tracks with the detected PHALP tracks
15
+ Will write a job specification file to ../job_specs with which track IDs to run optimization on
16
+ """
17
+
18
+ IMG_ROOT = f"{EGOBODY_ROOT}/egocentric_color"
19
+ PHALP_DIR = f"{EGOBODY_ROOT}/slahmr/phalp_out/results"
20
+
21
+
22
def load_split_sequences(split):
    """
    Return the recording names of a data split, or [] when the split column
    is missing from the splits csv.
    """
    # fix: the path was a plain string literal missing the f-prefix, so it
    # pointed at a literal "{EGOBODY_ROOT}/data_splits.csv" path
    split_file = f"{EGOBODY_ROOT}/data_splits.csv"
    df = pd.read_csv(split_file)
    if split not in df.columns:
        print(f"{split} not in {split_file}")
        return []
    return df[split].dropna().tolist()
29
+
30
+
31
def get_egobody_keypoints(img_dir, start, end):
    """
    Load 2D keypoints and per-frame validity for the frames in
    img_dir/PV sliced to [start:end]; frames without stored
    annotations get zeros / False.
    """
    kp_file = f"{img_dir}/keypoints.npz"
    valid_file = f"{img_dir}/valid_frame.npz"
    img_paths = sorted(glob.glob(f"{img_dir}/PV/*.jpg"))[start:end]
    img_names = [os.path.basename(x) for x in img_paths]

    kp_dict = {}
    valid_dict = {}
    kp_data = np.load(kp_file)
    valid_data = np.load(valid_file)

    # zero template for frames without annotations
    zeros = np.zeros_like(kp_data["keypoints"][0])
    for img_path, kps in zip(kp_data["imgname"], kp_data["keypoints"]):
        img_name = os.path.basename(img_path)
        kp_dict[img_name] = kps

    for img_path, valid in zip(valid_data["imgname"], valid_data["valid"]):
        img_name = os.path.basename(img_path)
        valid_dict[img_name] = valid

    kps = np.stack([kp_dict.get(name, zeros) for name in img_names], axis=0)
    valid = np.stack([valid_dict.get(name, False) for name in img_names], axis=0)
    return kps, valid
54
+
55
+
56
def select_phalp_tracks(seq_name, img_dir, start, end, debug=False):
    """
    Get the best phalp track for each GT person for each frame.
    Returns all phalp tracks that match GT over sequence.
    NOTE(review): this module imports `associate_frame` from `associate`;
    the version visible elsewhere in this repo is named
    `associate_frame_dict` -- verify the import resolves as intended.
    """
    phalp_file = f"{PHALP_DIR}/{seq_name}.pkl"
    track_data = joblib.load(phalp_file)
    img_names = sorted(track_data.keys())
    sel_imgs = img_names[start:end]

    kps_all, valid = get_egobody_keypoints(img_dir, start, end)
    T = len(kps_all)
    assert len(sel_imgs) == T, f"found {len(sel_imgs)} frames, expected {T}"

    # collect every track id seen in the selected frames
    track_ids = set()
    for frame in sel_imgs:
        frame_data = track_data[frame]
        for tid in frame_data["tracked_ids"]:
            track_ids.add(str(tid))
    track_ids = list(track_ids)
    # NOTE(review): M and track_idcs are computed but never used below
    M = len(track_ids)
    track_idcs = {tid: m for m, tid in enumerate(track_ids)}

    # get the best matching PHALP track for each GT person
    sel_tracks = set()
    for t, frame_name in enumerate(sel_imgs):
        frame_data = track_data[frame_name]
        # get the best track ID for the GT person
        tid = associate_frame(frame_data, kps_all[t], track_ids, debug=debug)
        if tid == -1:
            continue
        sel_tracks.add(int(tid))
    return list(sel_tracks)
89
+
90
+
91
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seq_len", type=int, default=100)
    parser.add_argument("--split", default="val", choices=["train", "val", "test"])
    parser.add_argument("--prefix", default="ego")
    args = parser.parse_args()

    seqs = load_split_sequences(args.split)

    job_arg_strs = []
    for seq in seqs:
        # resolve the single recording subdirectory for this sequence
        img_dir = glob.glob(f"{IMG_ROOT}/{seq}/**/")[0]
        num_imgs = len(glob.glob(f"{img_dir}/PV/*.jpg"))
        # chunk the sequence into seq_len-frame windows
        splits = list(range(0, num_imgs, args.seq_len))
        splits[-1] = num_imgs  # just add the remainder to the last job
        for start, end in zip(splits[:-1], splits[1:]):
            sel_tracks = select_phalp_tracks(seq, img_dir, start, end)
            if len(sel_tracks) < 1:
                continue
            track_str = "-".join([f"{tid:03d}" for tid in sel_tracks])
            arg_str = (
                f"{seq} data.start_idx={start} data.end_idx={end} "
                f"data.track_ids={track_str}"
            )
            print(arg_str)
            job_arg_strs.append(arg_str)

    with open(
        f"../job_specs/{args.prefix}_{args.split}_len_{args.seq_len}_tracks.txt", "w"
    ) as f:
        f.write("\n".join(job_arg_strs))
slahmr/slahmr/eval/tools.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import glob
4
+ import json
5
+ import joblib
6
+ import numpy as np
7
+ import smplx
8
+ import torch
9
+
10
+ from util.loaders import load_smpl_body_model
11
+ from util.tensor import move_to, detach_all, to_torch
12
+ from optim.output import load_result, get_results_paths
13
+ from geometry.pcl import align_pcl
14
+ from geometry.rotation import batch_rodrigues
15
+
16
+ BASE_DIR = os.path.abspath(f"{__file__}/../../../")
17
+ JOINT_REG_PATH = f"{BASE_DIR}/_DATA/body_models/J_regressor_h36m.npy"
18
+
19
+
20
+ # XXX: Sorry, need to change this yourself
21
+ EGOBODY_ROOT = "/path/to/egobody"
22
+ TDPW_ROOT = "/path/to/3DPW"
23
+
24
+
25
class JointRegressor(object):
    """Regress the 14 standard H36M evaluation joints from SMPL vertices."""

    def __init__(self):
        # (17, 6890) H36M joint regressor over the SMPL vertex set,
        # loaded from disk at JOINT_REG_PATH
        R17 = torch.from_numpy(np.load(JOINT_REG_PATH).astype(np.float32))
        # (14,) adding the root, but will omit
        joint_map_h36m = torch.tensor([6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10])
        self.regressor = R17[joint_map_h36m]  # (14, 6890)

    def to(self, device):
        # moves the regressor weights in place; returns None (not chainable)
        self.regressor = self.regressor.to(device)

    def __call__(self, verts):
        """
        NOTE: RETURNS ROOT AS WELL
        :param verts (*, V, 3)
        returns (*, J, 3) 14 standard evaluation joints
        """
        return torch.einsum("nv,...vd->...nd", self.regressor, verts)  # (..., 14, 3)
43
+
44
+
45
+ def compute_accel_norm(joints):
46
+ """
47
+ :param joints (T, J, 3)
48
+ """
49
+ vel = joints[1:] - joints[:-1] # (T-1, J, 3)
50
+ acc = vel[1:] - vel[:-1] # (T-2, J, 3)
51
+ return torch.linalg.norm(acc, dim=-1)
52
+
53
+
54
def global_align_joints(gt_joints, pred_joints):
    """
    Align pred to gt with a single similarity transform fit over all frames.
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    returns aligned predicted joints (T, J, 3)
    """
    # fit one global (scale, rotation, translation) over the stacked points
    scale, rot, trans = align_pcl(
        gt_joints.reshape(-1, 3), pred_joints.reshape(-1, 3)
    )
    rotated = torch.einsum("ij,tnj->tni", rot, pred_joints)
    return scale * rotated + trans[None, None]
66
+
67
+
68
def first_align_joints(gt_joints, pred_joints):
    """
    Align pred to gt with a similarity transform fit to the first two frames.
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    returns aligned predicted joints (T, J, 3)
    """
    # fit a single transform over the stacked joints of frames 0 and 1
    # shapes: (1, 1), (1, 3, 3), (1, 3)
    scale, rot, trans = align_pcl(
        gt_joints[:2].reshape(1, -1, 3), pred_joints[:2].reshape(1, -1, 3)
    )
    rotated = torch.einsum("tij,tnj->tni", rot, pred_joints)
    return scale * rotated + trans[:, None]
82
+
83
+
84
def local_align_joints(gt_joints, pred_joints):
    """
    Procrustes-align pred to gt independently in every frame.
    :param gt_joints (T, J, 3)
    :param pred_joints (T, J, 3)
    returns aligned predicted joints (T, J, 3)
    """
    # per-frame (scale, rotation, translation)
    scale, rot, trans = align_pcl(gt_joints, pred_joints)
    rotated = torch.einsum("tij,tnj->tni", rot, pred_joints)
    return scale[:, None] * rotated + trans[:, None]
95
+
96
+
97
def load_body_model(batch_size, model_type, gender, device):
    """
    Construct a SMPL/SMPL+H body model for evaluation.
    :param batch_size number of body instances the model is built for
    :param model_type "smpl" or "smplh"
    :param gender body-model gender string, e.g. "neutral"
    :param device torch device
    returns the loaded body model
    """
    assert model_type in ["smpl", "smplh"]
    if model_type == "smpl":
        num_betas = 10
        ext = "pkl"
        use_vtx_selector = False
    else:
        # SMPL+H checkpoints ship as npz with 16 shape coefficients
        num_betas = 16
        ext = "npz"
        use_vtx_selector = True

    # NOTE(review): JOINT_REG_PATH above lives under _DATA/body_models while
    # this path is BASE_DIR/body_models — confirm both locations exist
    smpl_path = f"{BASE_DIR}/body_models/{model_type}/{gender}/model.{ext}"
    body_model, fit_gender = load_smpl_body_model(
        smpl_path,
        batch_size,
        num_betas,
        model_type=model_type,
        use_vtx_selector=use_vtx_selector,
        device=device,
    )
    return body_model
118
+
119
+
120
def run_smpl(body_model, *args, **kwargs):
    """
    Run the body model without gradients.
    returns dict of detached CPU tensors: joints (Jtr), vertices (v), faces (f)
    """
    with torch.no_grad():
        results = body_model(*args, **kwargs)
    return {
        "joints": results.Jtr.detach().cpu(),
        "vertices": results.v.detach().cpu(),
        "faces": results.f.detach().cpu(),
    }
128
+
129
+
130
def run_smpl_batch(body_model, device, **kwargs):
    """
    Run the body model on inputs with extra leading dims by flattening them
    into the model batch dim, then unflattening the outputs.
    :param body_model body model built with batch_size B (== product of each
        kwarg's leading dims)
    :param device torch device to run on
    :param kwargs tensors of shape (*leading, D)
    """
    model_kwargs = {}
    B = body_model.bm.batch_size
    kwarg_shape = (B,)
    for k, v in kwargs.items():
        # NOTE(review): kwarg_shape is overwritten every iteration, so all
        # kwargs are assumed to share the same leading shape — confirm
        kwarg_shape = v.shape[:-1]
        model_kwargs[k] = v.reshape(B, v.shape[-1]).to(device)
    res_flat = run_smpl(body_model, **model_kwargs)
    res = {}
    for k, v in res_flat.items():
        sh = v.shape
        # unflatten back to the input leading shape where the output is batched
        if sh[0] == B:
            v = v.reshape(*kwarg_shape, *sh[1:])
        res[k] = v
    return res
145
+
146
+
147
def cat_dicts(dict_list, dim=0):
    """
    Stack a list of dicts of tensors into a single dict of stacked tensors.
    All dicts must share the same key set.
    :param dict_list list of {key: tensor}
    :param dim (optional) dimension to stack along, default 0
    """
    ref_keys = set(dict_list[0].keys())
    for d in dict_list:
        assert set(d.keys()) == ref_keys
    out = {}
    for key in ref_keys:
        out[key] = torch.stack([d[key] for d in dict_list], dim=dim)
    return out
154
+
155
+
156
def load_results_all(phase_dir, device):
    """
    Load all the reconstructed tracks during optimization
    and run the body model on the final iteration's parameters.
    :param phase_dir directory of per-iteration optimization results
    :param device torch device
    returns dict of SMPL outputs with (B, T, ...) leading dims, or None when
        the optimization stopped before 20 iterations
    """
    res_path_dict = get_results_paths(phase_dir)
    max_iter = max(res_path_dict.keys())
    if int(max_iter) < 20:
        print("max_iter", max_iter)
        return None

    res = load_result(res_path_dict[max_iter])["world"]
    # results is dict with (B, T, *) tensors
    trans = res["trans"]
    B, T, _ = trans.shape
    root_orient = res["root_orient"]
    pose_body = res["pose_body"]
    # betas are per-track; broadcast them over time
    betas = res["betas"].reshape(B, 1, -1).expand(B, T, -1)
    body_model = load_body_model(B * T, "smplh", "neutral", device)
    return run_smpl_batch(
        body_model,
        device,
        trans=trans,
        root_orient=root_orient,
        betas=betas,
        pose_body=pose_body,
    )
slahmr/slahmr/geometry/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from . import camera
2
+ from . import mesh
3
+ from . import pcl
4
+ from . import plane
5
+ from . import rotation
slahmr/slahmr/geometry/camera.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
def perspective_projection(
    points, focal_length, camera_center, rotation=None, translation=None
):
    """
    Adapted from https://github.com/mkocabas/VIBE/blob/master/lib/models/spin.py
    Project 3D points to the image plane with a pinhole camera.
    Input:
        points (bs, N, 3): 3D points
        focal_length (bs, 2): Focal length
        camera_center (bs, 2): Camera center
        rotation (bs, 3, 3): OPTIONAL Camera rotation
        translation (bs, 3): OPTIONAL Camera translation
    returns (bs, N, 2) pixel coordinates
    """
    bs = points.shape[0]
    # assemble the per-batch intrinsics matrix K
    K = torch.zeros(bs, 3, 3, device=points.device)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 1]
    K[:, 2, 2] = 1.0
    K[:, :2, 2] = camera_center

    # optionally move points into the camera frame
    if rotation is not None and translation is not None:
        points = torch.einsum("bij,bkj->bki", rotation, points)
        points = points + translation[:, None]

    # perspective divide, then apply intrinsics and drop the homogeneous coord
    normalized = points / points[..., 2:3]
    pixels = torch.einsum("bij,bkj->bki", K, normalized)
    return pixels[..., :2]
37
+
38
+
39
def reproject(points3d, cam_R, cam_t, cam_f, cam_center):
    """
    reproject points3d into the scene cameras
    :param points3d (B, T, N, 3)
    :param cam_R (B, T, 3, 3) world-to-camera rotations
    :param cam_t (B, T, 3) world-to-camera translations
    :param cam_f (T, 2) focal lengths
    :param cam_center (T, 2) principal points
    returns (B, T, N, 2) pixel coordinates
    """
    B, T, N, _ = points3d.shape
    # rotate and translate into each camera frame
    points3d = torch.einsum("btij,btnj->btni", cam_R, points3d)
    points3d = points3d + cam_t[..., None, :]  # (B, T, N, 3)
    # perspective divide, then per-frame intrinsics
    points2d = points3d[..., :2] / points3d[..., 2:3]
    points2d = cam_f[None, :, None] * points2d + cam_center[None, :, None]
    return points2d
54
+
55
+
56
def focal2fov(focal, R):
    """
    Field of view (radians) from focal length.
    :param focal, focal length
    :param R, either W / 2 or H / 2
    """
    half_angle = np.arctan(R / focal)
    return 2.0 * half_angle
62
+
63
+
64
def fov2focal(fov, R):
    """
    Focal length from field of view.
    :param fov, field of view in radians
    :param R, either W / 2 or H / 2
    """
    half_tan = np.tan(0.5 * fov)
    return R / half_tan
70
+
71
+
72
def compute_lookat_box(bb_min, bb_max, intrins):
    """
    The center and distance to a scene with bb_min, bb_max
    to place a camera with given intrinsics
    :param bb_min (3,)
    :param bb_max (3,)
    :param intrins, (fx, fy, cx, cy) of camera
    returns center (3,) tensor, cam_dist scalar
    """
    fx, fy, cx, cy = intrins
    bb_min, bb_max = torch.tensor(bb_min), torch.tensor(bb_max)
    center = 0.5 * (bb_min + bb_max)
    # bounding-box diagonal as the scene size
    size = torch.linalg.norm(bb_max - bb_min)
    # heuristic: scale distance by the focal/principal-point ratio so the box
    # roughly fills the view — NOTE(review): 0.75 factor is empirical, confirm
    cam_dist = np.sqrt(fx**2 + fy**2) / np.sqrt(cx**2 + cy**2)
    cam_dist = 0.75 * size * cam_dist
    return center, cam_dist
88
+
89
+
90
def lookat_origin(cam_dist, view_angle=-np.pi / 6):
    """
    Camera pose at distance cam_dist looking at the origin from an elevated
    angle (rotation about x).
    :param cam_dist (float)
    :param view_angle (float) elevation in radians (sign is ignored)
    returns rot (3, 3), pos (3,)
    """
    cam_dist = np.abs(cam_dist)
    view_angle = np.abs(view_angle)
    # position on the y-z arc at the given elevation
    pos = cam_dist * torch.tensor([0, np.sin(view_angle), np.cos(view_angle)])
    rot = rotx(view_angle)
    return rot, pos
100
+
101
+
102
def lookat_matrix(source_pos, target_pos, up):
    """
    IMPORTANT: USES RIGHT UP BACK XYZ CONVENTION
    :param source_pos (*, 3) camera position
    :param target_pos (*, 3) point to look at
    :param up (3,) world up vector
    returns (*, 4, 4) camera-to-world pose
    """
    *dims, _ = source_pos.shape
    up = up.reshape(*(1,) * len(dims), 3)
    up = up / torch.linalg.norm(up, dim=-1, keepdim=True)
    # build an orthonormal right/up/back frame via successive cross products
    back = normalize(target_pos - source_pos)
    right = normalize(torch.linalg.cross(up, back))
    up = normalize(torch.linalg.cross(back, right))
    R = torch.stack([right, up, back], dim=-1)
    return make_4x4_pose(R, source_pos)
117
+
118
+
119
def normalize(x):
    # unit-normalize along the last dim
    # NOTE(review): shadowed by a second `normalize` defined later in this
    # module; both compute the same value
    return x / torch.linalg.norm(x, dim=-1, keepdim=True)
121
+
122
+
123
def invert_camera(R, t):
    """
    Invert a rigid transform [R, t].
    :param R (*, 3, 3)
    :param t (*, 3)
    returns Ri (*, 3, 3), ti (*, 3) such that [Ri, ti] o [R, t] is identity
    """
    # FIX: torch.as_tensor avoids the copy/detach (and the UserWarning) that
    # torch.tensor() performs when the inputs are already tensors, and keeps
    # gradients flowing through tensor inputs
    R, t = torch.as_tensor(R), torch.as_tensor(t)
    Ri = R.transpose(-1, -2)  # inverse of a rotation is its transpose
    ti = -torch.einsum("...ij,...j->...i", Ri, t)
    return Ri, ti
133
+
134
+
135
def compose_cameras(R1, t1, R2, t2):
    """
    Compose the rigid transforms [R1, t1] and [R2, t2]
    (apply [R2, t2] first, then [R1, t1]).
    :param R1 (*, 3, 3)
    :param t1 (*, 3)
    :param R2 (*, 3, 3)
    :param t2 (*, 3)
    returns R (*, 3, 3), t (*, 3)
    """
    R_out = torch.einsum("...ij,...jk->...ik", R1, R2)
    t_out = torch.einsum("...ij,...j->...i", R1, t2) + t1
    return R_out, t_out
146
+
147
+
148
def matmul_nd(A, x):
    """
    multiply batch matrix A to batch nd tensors
    :param A (B, m, n)
    :param x (B, *dims, m)
    returns (B, *dims, m)
    """
    B, m, n = A.shape
    assert len(A) == len(x)
    # NOTE(review): the assert checks x's last dim against m, but the matmul
    # below requires it to equal n — only consistent for square matrices
    assert x.shape[-1] == m
    B, *dims, _ = x.shape
    # broadcast A over the middle dims and contract the last dim of x
    return torch.matmul(A.reshape(B, *(1,) * len(dims), m, n), x[..., None])[..., 0]
159
+
160
+
161
def view_matrix(z, up, pos):
    """
    Build a camera pose from a view direction, up hint, and position.
    :param z (*, 3) view (back) axis; up (*, 3) up hint; pos (*, 3) position
    returns (*, 4, 4) pose with columns [x, y, z, pos] over a [0, 0, 0, 1] row
    """
    *dims, _ = z.shape
    # orthonormalize: x from up x z, then re-derive y
    x = normalize(torch.linalg.cross(up, z))
    y = normalize(torch.linalg.cross(z, x))
    bottom = (
        torch.tensor([0, 0, 0, 1], dtype=torch.float32)
        .reshape(*(1,) * len(dims), 1, 4)
        .expand(*dims, 1, 4)
    )

    return torch.cat([torch.stack([x, y, z, pos], dim=-1), bottom], dim=-2)
176
+
177
+
178
def average_pose(poses):
    """
    :param poses (N, 4, 4)
    returns average pose (4, 4)
    """
    # mean position, and renormalized mean up / view axes (columns 1 and 2)
    center = poses[:, :3, 3].mean(0)
    up = normalize(poses[:, :3, 1].sum(0))
    z = normalize(poses[:, :3, 2].sum(0))
    return view_matrix(z, up, center)
187
+
188
+
189
def project_so3(M, eps=1e-4):
    """
    Project matrices onto SO(3) (nearest rotation) via SVD.
    :param M (N, *, 3, 3)
    returns (N, *, 3, 3) rotation matrices with det = +1
    """
    N, *dims, _, _ = M.shape
    # NOTE(review): this random scaling makes the function non-deterministic
    # and perturbs the projection; presumably intended to break SVD
    # degeneracies — confirm. The `eps` parameter is unused.
    M = M * (1 + torch.rand(N, *dims, 1, 3, device=M.device))
    U, D, Vt = torch.linalg.svd(M)  # (N, *, 3, 3), (N, *, 3), (N, *, 3, 3)
    # flip the last singular direction wherever U @ Vt would be a reflection
    detuvt = torch.linalg.det(torch.matmul(U, Vt))  # (N, *)
    S = torch.cat(
        [torch.ones(N, *dims, 2, device=M.device), detuvt[..., None]], dim=-1
    )  # (N, *, 3)
    return torch.matmul(U, torch.matmul(torch.diag_embed(S), Vt))
201
+
202
+
203
def make_translation(t):
    """Homogeneous 4x4 pose with identity rotation and translation t (3,)."""
    return make_4x4_pose(torch.eye(3), t)
205
+
206
+
207
def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
    """
    Build a rotation-only 4x4 pose from Euler angles.
    :param rx, ry, rz rotation angles (radians) about x, y, z
    :param order axis order in which the rotations are applied
    returns (4, 4) homogeneous transform
    :raises ValueError if order is not a recognized permutation of "xyz"
    """
    Rx, Ry, Rz = rotx(rx), roty(ry), rotz(rz)
    # matrix for each application order (first axis applied right-most)
    compositions = {
        "xyz": Rz @ Ry @ Rx,
        "xzy": Ry @ Rz @ Rx,
        "yxz": Rz @ Rx @ Ry,
        "yzx": Rx @ Rz @ Ry,
        "zyx": Rx @ Ry @ Rz,
        "zxy": Ry @ Rx @ Rz,
    }
    if order not in compositions:
        # FIX: previously an unknown order fell through and crashed with
        # NameError on R; fail loudly with a clear message instead
        raise ValueError(f"unknown rotation order {order}")
    return make_4x4_pose(compositions[order], torch.zeros(3))
224
+
225
+
226
def make_4x4_pose(R, t):
    """
    Stack a rotation and translation into a homogeneous 4x4 pose.
    :param R (*, 3, 3)
    :param t (*, 3)
    return (*, 4, 4)
    """
    batch = R.shape[:-2]
    top = torch.cat([R, t.view(*batch, 3, 1)], dim=-1)  # (*, 3, 4)
    last_row = (
        torch.tensor([0, 0, 0, 1], device=R.device)
        .reshape(*(1,) * len(batch), 1, 4)
        .expand(*batch, 1, 4)
    )
    return torch.cat([top, last_row], dim=-2)
240
+
241
+
242
def normalize(x):
    # unit-normalize along the last dim
    # NOTE(review): this redefinition shadows the earlier `normalize` in this
    # module; both compute the same value
    return x / torch.sqrt(torch.sum(x**2, dim=-1, keepdim=True))
244
+
245
+
246
def rotx(theta):
    """Rotation matrix (3, 3) float32 about the x-axis by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    rows = [[1, 0, 0], [0, c, -s], [0, s, c]]
    return torch.tensor(rows, dtype=torch.float32)
255
+
256
+
257
def roty(theta):
    """Rotation matrix (3, 3) float32 about the y-axis by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    rows = [[c, 0, s], [0, 1, 0], [-s, 0, c]]
    return torch.tensor(rows, dtype=torch.float32)
266
+
267
+
268
def rotz(theta):
    """Rotation matrix (3, 3) float32 about the z-axis by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    rows = [[c, -s, 0], [s, c, 0], [0, 0, 1]]
    return torch.tensor(rows, dtype=torch.float32)
277
+
278
+
279
def relative_pose_c2w(Rwc1, Rwc2, twc1, twc2):
    """
    compute relative pose from cam 1 to cam 2 given c2w pose matrices
    :param Rwc1, Rwc2 (N, 3, 3) cam1, cam2 to world rotations
    :param twc1, twc2 (N, 3) cam1, cam2 to world translations
    returns R21 (N, 3, 3) t21 (N, 3)
    """
    twc1 = twc1.view(-1, 3, 1)
    twc2 = twc2.view(-1, 3, 1)
    # invert cam2's pose: world to c2
    Rc2w = Rwc2.transpose(-1, -2)
    tc2w = -torch.matmul(Rc2w, twc2)
    # compose: c1 -> world -> c2
    Rc2c1 = torch.matmul(Rc2w, Rwc1)
    tc2c1 = tc2w + torch.matmul(Rc2w, twc1)
    return Rc2c1, tc2c1[..., 0]
293
+
294
+
295
def relative_pose_w2c(Rc1w, Rc2w, tc1w, tc2w):
    """
    compute relative pose from cam 1 to cam 2 given w2c camera matrices
    :param Rc1w, Rc2w (N, 3, 3) world to cam1, cam2 rotations
    :param tc1w, tc2w (N, 3) world to cam1, cam2 translations
    returns Rc2c1 (N, 3, 3), tc2c1 (N, 3)
    """
    tc1w = tc1w.view(-1, 3, 1)
    tc2w = tc2w.view(-1, 3, 1)
    # we keep the world to cam transforms
    # invert cam1's matrix: c1 to world
    Rwc1 = Rc1w.transpose(-1, -2)
    twc1 = -torch.matmul(Rwc1, tc1w)
    Rc2c1 = torch.matmul(Rc2w, Rwc1)  # c1 to c2
    tc2c1 = tc2w + torch.matmul(Rc2w, twc1)
    return Rc2c1, tc2c1[..., 0]
309
+
310
+
311
def project(xyz_c, center, focal, eps=1e-5):
    """
    Pinhole-project camera-frame points to pixel coordinates.
    :param xyz_c (*, 3) 3d point in camera coordinates
    :param focal (1) focal length
    :param center (*, 2) principal point
    :param eps stabilizer added to depth to avoid division by zero
    return (*, 2)
    """
    xy = xyz_c[..., :2]
    depth = xyz_c[..., 2:3] + eps
    return center + focal * xy / depth  # (N, *, 2)
319
+
320
+
321
def convert_yup(xyz):
    """
    Convert points from x-right / y-down / z-forward to
    x-right / y-up / z-back (negate y and z).
    :param xyz (*, 3) (any trailing channels beyond 3 are dropped)
    """
    x = xyz[..., 0:1]
    y = xyz[..., 1:2]
    z = xyz[..., 2:3]
    return torch.cat((x, y.neg(), z.neg()), dim=-1)
328
+
329
+
330
def inv_project(uv, z, center, focal, yup=True):
    """
    Back-project pixels with known depth into camera-frame 3D points.
    :param uv (*, 2) pixel coordinates
    :param z (*, 1) depth along the view axis
    :param center (*, 2) principal point
    :param focal (1) focal length
    :param yup whether to return points in the y-up / z-back convention
    :returns (*, 3)
    """
    uv = uv - center
    if yup:
        # flip y and z to match the right/up/back convention
        return z * torch.cat(
            [uv[..., :1] / focal, -uv[..., 1:2] / focal, -torch.ones_like(uv[..., :1])],
            dim=-1,
        )  # (N, *, 3)

    return z * torch.cat(
        [uv / focal, torch.ones_like(uv[..., :1])],
        dim=-1,
    )  # (N, *, 3)
slahmr/slahmr/geometry/mesh.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import trimesh
4
+
5
+
6
def get_mesh_bb(mesh):
    """
    Axis-aligned bounding box of a mesh.
    :param mesh - trimesh mesh object (anything with a (V, 3) `vertices` array)
    returns bb_min (3), bb_max (3)
    """
    # BUG FIX: min and max were swapped (bb_min was vertices.max and
    # bb_max was vertices.min), producing an inverted box
    bb_min = mesh.vertices.min(axis=0)
    bb_max = mesh.vertices.max(axis=0)
    return bb_min, bb_max
14
+
15
+
16
def get_scene_bb(meshes):
    """
    Axis-aligned bounding box of a (possibly nested) collection of meshes.
    :param meshes - a trimesh.Trimesh or a (potentially nested) list of them
    returns bb_min (3), bb_max (3)
    """
    # base case: a single mesh
    if isinstance(meshes, trimesh.Trimesh):
        return get_mesh_bb(meshes)

    # recurse into sub-collections and combine the per-element boxes
    bb_mins, bb_maxs = zip(*[get_scene_bb(mesh) for mesh in meshes])
    bb_mins = np.stack(bb_mins, axis=0)
    bb_maxs = np.stack(bb_maxs, axis=0)
    return bb_mins.min(axis=0), bb_maxs.max(axis=0)
28
+
29
+
30
def make_batch_mesh(verts, faces, colors):
    """
    convenience function to make batch of meshes
    meshs have same faces in batch, verts have same color in mesh
    :param verts (B, V, 3)
    :param faces (F, 3)
    :param colors (B, 3) one color per mesh, broadcast over its vertices
    returns list of B trimesh objects
    """
    B, V, _ = verts.shape
    return [make_mesh(verts[b], faces, colors[b, None].expand(V, -1)) for b in range(B)]
40
+
41
+
42
def make_mesh(verts, faces, colors=None, yup=True):
    """
    create a trimesh object for the faces and vertices
    :param verts (V, 3) tensor
    :param faces (F, 3) tensor
    :param colors (optional) (V, 3) tensor; defaults to 50% gray
    :param yup (optional bool) whether or not to save with Y up
    """
    verts = verts.detach().cpu().numpy()
    faces = faces.detach().cpu().numpy()
    if yup:
        # negate y and z to convert to the y-up / z-back convention
        verts = np.array([1, -1, -1])[None, :] * verts
    if colors is None:
        colors = np.ones_like(verts) * 0.5
    else:
        colors = colors.detach().cpu().numpy()
    # process=False keeps the vertex order untouched (no merging/cleanup)
    return trimesh.Trimesh(
        vertices=verts, faces=faces, vertex_colors=colors, process=False
    )
61
+
62
+
63
def save_mesh_scenes(out_dir, scenes):
    """
    Save a sequence of scenes to disk; single-mesh scenes go to OBJ files,
    multi-mesh scenes to GLB files.
    :param scenes, list of scenes (list of meshes)
    """
    assert isinstance(scenes, list)
    assert isinstance(scenes[0], list)
    B = len(scenes[0])  # number of meshes per scene
    if B == 1:
        save_meshes_to_obj(out_dir, [x[0] for x in scenes])
    else:
        save_scenes_to_glb(out_dir, scenes)
74
+
75
+
76
def save_scenes_to_glb(out_dir, scenes):
    """
    Saves a list of scenes (list of meshes) each to glb files
    named scene_000.glb, scene_001.glb, ... under out_dir.
    """
    os.makedirs(out_dir, exist_ok=True)
    for t, meshes in enumerate(scenes):
        save_meshes_to_glb(f"{out_dir}/scene_{t:03d}.glb", meshes)
83
+
84
+
85
def save_meshes_to_glb(path, meshes, names=None):
    """
    put trimesh meshes in a scene and export to glb
    :param path output .glb file path
    :param meshes list of trimesh objects
    :param names (optional) node names, defaults to mesh_000, mesh_001, ...
    """
    if names is not None:
        assert len(meshes) == len(names)

    scene = trimesh.Scene()
    for i, mesh in enumerate(meshes):
        name = f"mesh_{i:03d}" if names is None else names[i]
        scene.add_geometry(mesh, node_name=name)

    # binary glTF export
    with open(path, "wb") as f:
        f.write(trimesh.exchange.gltf.export_glb(scene, include_normals=True))
99
+
100
+
101
def save_meshes_to_obj(out_dir, meshes, names=None):
    """
    Export each mesh to its own OBJ file under out_dir (colors are dropped).
    :param names (optional) per-mesh file stems, default mesh_000, mesh_001, ...
    """
    if names is not None:
        assert len(meshes) == len(names)

    os.makedirs(out_dir, exist_ok=True)
    for i, mesh in enumerate(meshes):
        name = f"mesh_{i:03d}" if names is None else names[i]
        path = os.path.join(out_dir, f"{name}.obj")
        with open(path, "w") as f:
            mesh.export(f, file_type="obj", include_color=False, include_normals=True)
slahmr/slahmr/geometry/pcl.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
def read_pcl_tensor(path):
    """
    Load a point cloud from `path` and return it as a torch tensor.
    NOTE(review): `read_pcl` is neither defined nor imported in this module,
    so calling this raises NameError — confirm where it should come from.
    """
    pcl_np = read_pcl(path)
    return torch.from_numpy(pcl_np)
8
+
9
+
10
def align_pcl(Y, X, weight=None, fixed_scale=False):
    """align similarity transform to align X with Y using umeyama method
    X' = s * R * X + t is aligned with Y
    :param Y (*, N, 3) first trajectory
    :param X (*, N, 3) second trajectory
    :param weight (*, N, 1) optional weight of valid correspondences
    :param fixed_scale if True, force s = 1 (rigid alignment)
    :returns s (*, 1), R (*, 3, 3), t (*, 3)
    """
    *dims, N, _ = Y.shape
    # effective point count per batch element (replaced by weight sums below)
    # NOTE(review): created on the default device — presumably inputs are CPU
    # tensors here; confirm before calling with CUDA tensors
    N = torch.ones(*dims, 1, 1) * N

    if weight is not None:
        Y = Y * weight
        X = X * weight
        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)

    # subtract mean
    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)
    mx = X.sum(dim=-2) / N[..., 0]
    y0 = Y - my[..., None, :]  # (*, N, 3)
    x0 = X - mx[..., None, :]

    if weight is not None:
        y0 = y0 * weight
        x0 = x0 * weight

    # correlation
    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)
    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)

    # sign correction so R is a proper rotation (det = +1), not a reflection
    S = torch.eye(3).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)
    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
    S[neg, 2, 2] = -1

    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)

    D = torch.diag_embed(D)  # (*, 3, 3)
    if fixed_scale:
        s = torch.ones(*dims, 1, device=Y.device, dtype=torch.float32)
    else:
        # optimal scale: trace(D S) / var(x0)
        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)
        s = (
            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
                dim=-1, keepdim=True
            )
            / var[..., 0]
        )  # (*, 1)

    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)

    return s, R, t
slahmr/slahmr/geometry/plane.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def fit_plane(points):
    """
    Least-squares plane fit via SVD.
    :param points (*, N, 3)
    returns (*, 4) plane parameters (nx, ny, nz, d) where (nx, ny, nz) is the
        unit normal and d the offset, with n . p = d for points p on the plane
    """
    *dims, N, D = points.shape
    mean = points.mean(dim=-2, keepdim=True)
    # the normal is the direction of least variance: last right-singular vector
    # (*, N, D), (*, D), (*, D, D)
    U, S, Vh = torch.linalg.svd(points - mean)
    normal = Vh[..., -1, :]  # (*, D)
    # offset = mean projection of the points onto the normal
    offset = torch.einsum("...ij,...j->...i", points, normal)  # (*, N)
    offset = offset.mean(dim=-1, keepdim=True)
    return torch.cat([normal, offset], dim=-1)
17
+
18
+
19
def get_plane_transform(up, ground_plane=None, xyz_orig=None):
    """
    get R, t rigid transform from plane and desired origin
    :param up (3,) up vector of coordinate frame
    :param ground_plane (4) (a, b, c, d) where a,b,c is the normal
    :param xyz_orig (3) desired origin
    returns R (3, 3), t (3)
    """
    R = torch.eye(3)
    t = torch.zeros(3)
    if ground_plane is None:
        return R, t

    # compute transform between world up vector and passed in floor
    ground_plane = torch.as_tensor(ground_plane)
    # canonicalize the sign so d >= 0
    ground_plane = torch.sign(ground_plane[3]) * ground_plane

    # Rodrigues-style rotation taking `up` onto the plane normal
    # NOTE(review): degenerate when up is parallel to the normal
    # (ang_sin == 0 divides by zero) — confirm callers avoid that case
    normal = ground_plane[:3]
    normal = normal / torch.linalg.norm(normal)
    v = torch.linalg.cross(up, normal)
    ang_sin = torch.linalg.norm(v)
    ang_cos = up.dot(normal)
    skew_v = torch.as_tensor([[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]])
    R = torch.eye(3) + skew_v + (skew_v @ skew_v) * ((1.0 - ang_cos) / (ang_sin**2))

    # project origin onto plane
    if xyz_orig is None:
        xyz_orig = torch.zeros(3)
    t, _ = compute_plane_intersection(xyz_orig, -normal, ground_plane)

    return R, t
49
+
50
+
51
def parse_floor_plane(floor_plane):
    """
    Takes floor plane in the optimization form (Bx3 with a,b,c * d) and parses into
    (a,b,c,d) from with (a,b,c) normal facing "up in the camera frame and d the offset.
    """
    # the 3-vector encodes normal * offset, so its norm is the offset
    offset = torch.norm(floor_plane, dim=-1, keepdim=True)
    normal = floor_plane / offset

    # in camera system -y is up, so floor plane normal y component should never be positive
    # (assuming the camera is not sideways or upside down)
    flip = normal[..., 1:2] > 0.0
    normal = torch.where(flip.expand_as(normal), normal.neg(), normal)
    offset = torch.where(flip, offset.neg(), offset)

    return torch.cat([normal, offset], dim=-1)
69
+
70
+
71
def compute_plane_intersection(point, direction, plane):
    """
    Given a ray defined by a point in space and a direction,
    compute the intersection point with the given plane.
    Detect intersection in either direction or -direction.
    Note, ray may not actually intersect with the plane.

    Returns the intersection point and s where
    point + s * direction = intersection_point. if s < 0 it means
    -direction intersects.

    - point : B x 3
    - direction : B x 3
    - plane : B x 4 (a, b, c, d) where (a, b, c) is the normal and (d) the offset.
    """
    dims = point.shape[:-1]  # NOTE(review): unused — confirm it can be dropped
    plane_normal = plane[..., :3]
    plane_off = plane[..., 3]
    # ray parameter at which n . (point + s * direction) = d; the small
    # epsilon guards against rays parallel to the plane
    s = (plane_off - bdot(plane_normal, point)) / (bdot(plane_normal, direction) + 1e-4)
    itsct_pt = point + s.reshape((-1, 1)) * direction
    return itsct_pt, s
92
+
93
+
94
def bdot(A1, A2, keepdim=False):
    """
    Batched dot product along the last dimension.
    - A1 : B x D
    - A2 : B x D.
    Returns B (or B x 1 when keepdim=True).
    """
    prod = torch.mul(A1, A2)
    return prod.sum(dim=-1, keepdim=keepdim)
slahmr/slahmr/geometry/rotation.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch.nn import functional as F
4
+
5
+
6
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
    """
    Taken from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py
    Rodrigues formula: rotation matrices from a batch of axis-angle vectors.
    - param rot_vecs: torch.tensor (N, 3) array of N axis-angle vectors
    - returns R: torch.tensor (N, 3, 3) rotation matrices
    """
    N = rot_vecs.shape[0]
    device = rot_vecs.device

    # small constant keeps the norm (and the division below) finite at zero
    angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)  # (N, 1)
    axis = rot_vecs / angle

    cos = torch.cos(angle)[:, None]  # (N, 1, 1)
    sin = torch.sin(angle)[:, None]

    # skew-symmetric cross-product matrix of the (unit) axis
    ax, ay, az = torch.split(axis, 1, dim=1)  # each (N, 1)
    zero = torch.zeros((N, 1), dtype=dtype, device=device)
    K = torch.cat(
        [zero, -az, ay, az, zero, -ax, -ay, ax, zero], dim=1
    ).view(N, 3, 3)

    # R = I + sin(theta) K + (1 - cos(theta)) K^2
    eye = torch.eye(3, dtype=dtype, device=device)[None]
    return eye + sin * K + (1 - cos) * torch.bmm(K, K)
34
+
35
+
36
def quaternion_mul(q0, q1):
    """
    Hamilton product of two quaternions.
    EXPECTS WXYZ
    :param q0 (*, 4)
    :param q1 (*, 4)
    """
    w0, v0 = q0[..., :1], q0[..., 1:]
    w1, v1 = q1[..., :1], q1[..., 1:]
    w = w0 * w1 - (v0 * v1).sum(dim=-1, keepdim=True)
    xyz = w0 * v1 + w1 * v0 + torch.linalg.cross(v0, v1)
    return torch.cat([w, xyz], dim=-1)
47
+
48
+
49
def quaternion_inverse(q, eps=1e-8):
    """
    Quaternion inverse: conjugate divided by the squared magnitude.
    EXPECTS WXYZ
    :param q (*, 4)
    :param eps stabilizer for near-zero quaternions
    """
    w, v = q[..., :1], q[..., 1:]
    conjugate = torch.cat([w, -v], dim=-1)
    sq_mag = (q * q).sum(dim=-1, keepdim=True) + eps
    return conjugate / sq_mag
57
+
58
+
59
def quaternion_slerp(t, q0, q1, eps=1e-8):
    """
    Spherical linear interpolation between quaternions.
    :param t (*, 1) must be between 0 and 1
    :param q0 (*, 4)
    :param q1 (*, 4)
    """
    dims = q0.shape[:-1]
    t = t.view(*dims, 1)

    q0 = F.normalize(q0, p=2, dim=-1)
    q1 = F.normalize(q1, p=2, dim=-1)
    dot = (q0 * q1).sum(dim=-1, keepdim=True)

    # make sure we give the shortest rotation path (< 180d)
    neg = dot < 0
    q1 = torch.where(neg, -q1, q1)
    dot = torch.where(neg, -dot, dot)
    angle = torch.acos(dot)

    # if angle is too small, just do linear interpolation
    collin = torch.abs(dot) > 1 - eps
    fac = 1 / torch.sin(angle)
    w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)
    w1 = torch.where(collin, t, torch.sin(t * angle) * fac)
    slerp = q0 * w0 + q1 * w1
    return slerp
85
+
86
+
87
def rotation_matrix_to_angle_axis(rotation_matrix):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert rotation matrix to Rodrigues vector (via quaternions).
    NaNs from degenerate inputs are zeroed out.
    """
    quaternion = rotation_matrix_to_quaternion(rotation_matrix)
    aa = quaternion_to_angle_axis(quaternion)
    aa[torch.isnan(aa)] = 0.0
    return aa
97
+
98
+
99
def quaternion_to_angle_axis(quaternion):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert quaternion vector to angle axis of rotation.
    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    :param quaternion (*, 4) expects WXYZ
    :returns angle_axis (*, 3)
    """
    # unpack input and compute conversion
    q1 = quaternion[..., 1]
    q2 = quaternion[..., 2]
    q3 = quaternion[..., 3]
    sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3

    sin_theta = torch.sqrt(sin_squared_theta)
    cos_theta = quaternion[..., 0]
    # atan2 branch keeps the angle in the right quadrant for negative w
    two_theta = 2.0 * torch.where(
        cos_theta < 0.0,
        torch.atan2(-sin_theta, -cos_theta),
        torch.atan2(sin_theta, cos_theta),
    )

    # k = theta / sin(theta/2); near zero rotation use the limit value 2
    k_pos = two_theta / sin_theta
    k_neg = 2.0 * torch.ones_like(sin_theta)
    k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)

    angle_axis = torch.zeros_like(quaternion)[..., :3]
    angle_axis[..., 0] += q1 * k
    angle_axis[..., 1] += q2 * k
    angle_axis[..., 2] += q3 * k
    return angle_axis
132
+
133
+
134
def quaternion_to_rotation_matrix(quaternion):
    """
    Convert a quaternion to a rotation matrix.
    Taken from https://github.com/kornia/kornia, based on
    https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101
    https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247
    :param quaternion (N, 4) expects WXYZ order
    returns rotation matrix (N, 3, 3)
    """
    # normalize the input quaternion
    quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-12)
    *dims, _ = quaternion_norm.shape

    # unpack the normalized quaternion components
    w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)

    # compute the actual conversion (standard quaternion-to-matrix products)
    tx = 2.0 * x
    ty = 2.0 * y
    tz = 2.0 * z
    twx = tx * w
    twy = ty * w
    twz = tz * w
    txx = tx * x
    txy = ty * x
    txz = tz * x
    tyy = ty * y
    tyz = tz * y
    tzz = tz * z
    one = torch.tensor(1.0)

    # assemble the 9 matrix entries in row-major order
    matrix = torch.stack(
        (
            one - (tyy + tzz),
            txy - twz,
            txz + twy,
            txy + twz,
            one - (txx + tzz),
            tyz - twx,
            txz - twy,
            tyz + twx,
            one - (txx + tyy),
        ),
        dim=-1,
    ).view(*dims, 3, 3)
    return matrix
180
+
181
+
182
def angle_axis_to_quaternion(angle_axis):
    """
    This function is borrowed from https://github.com/kornia/kornia
    Convert angle axis to quaternion in WXYZ order
    :param angle_axis (*, 3)
    :returns quaternion (*, 4) WXYZ order
    """
    theta_sq = (angle_axis * angle_axis).sum(dim=-1, keepdim=True)  # (*, 1)
    theta = torch.sqrt(theta_sq)
    half = 0.5 * theta
    ones = torch.ones_like(half)
    # need to handle the zero rotation case:
    # sin(a x) / x -> a as x -> 0, so the limit of sin(theta/2)/theta is 0.5
    nonzero = theta_sq > 0
    scale = torch.where(nonzero, torch.sin(half) / theta, 0.5 * ones)
    w = torch.where(nonzero, torch.cos(half), ones)
    return torch.cat([w, scale * angle_axis], dim=-1)
200
+
201
+
202
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    This function is borrowed from https://github.com/kornia/kornia
    Convert rotation matrix to 4d quaternion vector
    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Four numerically-stable branches are computed (one per dominant diagonal
    pattern) and blended with mutually exclusive masks.
    :param rotation_matrix (N, 3, 3)
    returns quaternion (*, 4) in WXYZ order
    """
    *dims, m, n = rotation_matrix.shape
    rmat_t = torch.transpose(rotation_matrix.reshape(-1, m, n), -1, -2)

    # branch selectors from the diagonal entries
    mask_d2 = rmat_t[:, 2, 2] < eps

    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    # candidate quaternion when x is the dominant component
    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack(
        [
            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
            t0,
            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
        ],
        -1,
    )
    t0_rep = t0.repeat(4, 1).t()

    # candidate when y is dominant
    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack(
        [
            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
            t1,
            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
        ],
        -1,
    )
    t1_rep = t1.repeat(4, 1).t()

    # candidate when z is dominant
    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack(
        [
            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
            t2,
        ],
        -1,
    )
    t2_rep = t2.repeat(4, 1).t()

    # candidate when w is dominant (trace is large)
    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack(
        [
            t3,
            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
        ],
        -1,
    )
    t3_rep = t3.repeat(4, 1).t()

    # mutually exclusive masks select exactly one candidate per matrix
    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * ~mask_d0_d1
    mask_c2 = ~mask_d2 * mask_d0_nd1
    mask_c3 = ~mask_d2 * ~mask_d0_nd1
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    # normalize by the selected branch's magnitude term
    q /= torch.sqrt(
        t0_rep * mask_c0
        + t1_rep * mask_c1
        + t2_rep * mask_c2  # noqa
        + t3_rep * mask_c3
    )  # noqa
    q *= 0.5
    return q.reshape(*dims, 4)
slahmr/slahmr/humor/__init__.py ADDED
File without changes
slahmr/slahmr/humor/amass_utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from https://github.com/davrempe/humor
3
+ """
4
+
5
+ from body_model.utils import SMPL_JOINTS
6
+
7
+
8
+ TRAIN_DATASETS = [
9
+ "CMU",
10
+ "MPI_Limits",
11
+ "TotalCapture",
12
+ "Eyes_Japan_Dataset",
13
+ "KIT",
14
+ "BioMotionLab_NTroje",
15
+ "BMLmovi",
16
+ "EKUT",
17
+ "ACCAD",
18
+ ]
19
+ TEST_DATASETS = ["Transitions_mocap", "HumanEva"]
20
+ VAL_DATASETS = ["MPI_HDM05", "SFU", "MPI_mosh"]
21
+
22
+
23
+ SPLITS = ["train", "val", "test", "custom"]
24
+ SPLIT_BY = [
25
+ "single", # the data path is a single .npz file. Don't split: train and test are same
26
+ "sequence", # the data paths are directories of subjects. Collate and split by sequence.
27
+ "subject", # the data paths are directories of datasets. Collate and split by subject.
28
+ "dataset", # a single data path to the amass data root is given. The predefined datasets will be used for each split.
29
+ ]
30
+
31
+ ROT_REPS = ["mat", "aa", "6d"]
32
+
33
+ # these correspond to [root, left knee, right knee, left heel, right heel, left toe, right toe, left hand, right hand]
34
+ CONTACT_ORDERING = [
35
+ "hips",
36
+ "leftLeg",
37
+ "rightLeg",
38
+ "leftFoot",
39
+ "rightFoot",
40
+ "leftToeBase",
41
+ "rightToeBase",
42
+ "leftHand",
43
+ "rightHand",
44
+ ]
45
+ CONTACT_INDS = [SMPL_JOINTS[jname] for jname in CONTACT_ORDERING]
46
+
47
+ NUM_BODY_JOINTS = len(SMPL_JOINTS) - 1
48
+ NUM_KEYPT_VERTS = 43
49
+
50
+ DATA_NAMES = [
51
+ "trans",
52
+ "trans_vel",
53
+ "root_orient",
54
+ "root_orient_vel",
55
+ "pose_body",
56
+ "pose_body_vel",
57
+ "joints",
58
+ "joints_vel",
59
+ "joints_orient_vel",
60
+ "verts",
61
+ "verts_vel",
62
+ "contacts",
63
+ ]
64
+
65
+ SMPL_JOINTS_RETURN_CONFIG = {
66
+ "trans": True,
67
+ "trans_vel": True,
68
+ "root_orient": True,
69
+ "root_orient_vel": True,
70
+ "pose_body": True,
71
+ "pose_body_vel": False,
72
+ "joints": True,
73
+ "joints_vel": True,
74
+ "joints_orient_vel": False,
75
+ "verts": False,
76
+ "verts_vel": False,
77
+ "contacts": False,
78
+ }
79
+
80
+ SMPL_JOINTS_CONTACTS_RETURN_CONFIG = {
81
+ "trans": True,
82
+ "trans_vel": True,
83
+ "root_orient": True,
84
+ "root_orient_vel": True,
85
+ "pose_body": True,
86
+ "pose_body_vel": False,
87
+ "joints": True,
88
+ "joints_vel": True,
89
+ "joints_orient_vel": False,
90
+ "verts": False,
91
+ "verts_vel": False,
92
+ "contacts": True,
93
+ }
94
+
95
+ ALL_RETURN_CONFIG = {
96
+ "trans": True,
97
+ "trans_vel": True,
98
+ "root_orient": True,
99
+ "root_orient_vel": True,
100
+ "pose_body": True,
101
+ "pose_body_vel": False,
102
+ "joints": True,
103
+ "joints_vel": True,
104
+ "joints_orient_vel": False,
105
+ "verts": True,
106
+ "verts_vel": False,
107
+ "contacts": True,
108
+ }
109
+
110
+ RETURN_CONFIGS = {
111
+ "smpl+joints+contacts": SMPL_JOINTS_CONTACTS_RETURN_CONFIG,
112
+ "smpl+joints": SMPL_JOINTS_RETURN_CONFIG,
113
+ "all": ALL_RETURN_CONFIG,
114
+ }
115
+
116
+
117
def data_name_list(return_config):
    """Return the data-stream names enabled in the given return configuration."""
    enabled = RETURN_CONFIGS[return_config]
    names = []
    for name in DATA_NAMES:
        if enabled[name]:
            names.append(name)
    return names
124
+
125
+
126
def data_dim(dname, rot_rep_size=9):
    """
    Return the flattened feature dimension of the data stream ``dname``.

    Rotation-valued streams ("root_orient", "pose_body") use ``rot_rep_size``
    per joint (9 for matrices, 6 for the 6d representation, 3 for axis-angle).

    Raises:
        ValueError: if ``dname`` is not a known data name.
    """
    if dname in ["trans", "trans_vel", "root_orient_vel"]:
        return 3
    elif dname in ["root_orient"]:
        return rot_rep_size
    elif dname in ["pose_body"]:
        return NUM_BODY_JOINTS * rot_rep_size
    elif dname in ["pose_body_vel"]:
        return NUM_BODY_JOINTS * 3
    elif dname in ["joints", "joints_vel"]:
        return len(SMPL_JOINTS) * 3
    elif dname in ["joints_orient_vel"]:
        return 1
    elif dname in ["verts", "verts_vel"]:
        return NUM_KEYPT_VERTS * 3
    elif dname in ["contacts"]:
        return len(CONTACT_ORDERING)
    else:
        # Raise instead of the original print + exit() so callers can catch the
        # error (and see a stack trace) rather than having the process killed.
        raise ValueError("The given data name %s is not valid!" % (dname))
slahmr/slahmr/humor/humor_model.py ADDED
@@ -0,0 +1,1655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from https://github.com/davrempe/humor
3
+ """
4
+
5
+ import time, os
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.distributions.normal import Normal
11
+
12
+ from .amass_utils import data_name_list, data_dim
13
+ from .transforms import (
14
+ convert_to_rotmat,
15
+ compute_world2aligned_mat,
16
+ rotation_matrix_to_angle_axis,
17
+ )
18
+
19
+ from body_model.specs import SMPL_JOINTS, SMPLH_PATH
20
+ from body_model.body_model import BodyModel
21
+
22
+
23
+ IN_ROT_REPS = ["aa", "6d", "mat"]
24
+ OUT_ROT_REPS = ["aa", "6d", "9d"]
25
+ ROT_REP_SIZE = {"aa": 3, "6d": 6, "mat": 9, "9d": 9}
26
+ NUM_SMPL_JOINTS = len(SMPL_JOINTS)
27
+ NUM_BODY_JOINTS = NUM_SMPL_JOINTS - 1 # no root
28
+ BETA_SIZE = 16
29
+
30
+ POSTERIOR_OPTIONS = ["mlp"]
31
+ PRIOR_OPTIONS = ["mlp"]
32
+ DECODER_OPTIONS = ["mlp"]
33
+
34
+ WORLD2ALIGN_NAME_CACHE = {
35
+ "root_orient": None,
36
+ "trans": None,
37
+ "joints": None,
38
+ "verts": None,
39
+ "joints_vel": None,
40
+ "verts_vel": None,
41
+ "trans_vel": None,
42
+ "root_orient_vel": None,
43
+ }
44
+
45
+
46
def step(
    model, loss_func, data, dataset, device, cur_epoch, mode="train", use_gt_p=1.0
):
    """
    Given data for the current training step (batch),
    pulls out the necessary needed data,
    runs the model,
    calculates and returns the loss.

    - use_gt_p : the probability of using ground truth as input to each step rather than the model's own prediction
                 (1.0 is fully supervised, 0.0 is fully autoregressive)

    Returns (loss, stats_dict) from ``loss_func``.

    NOTE(review): ``dataset`` and ``mode`` are accepted but never used in this
    body — presumably kept for a shared training-loop signature; confirm
    against the caller before removing.
    """
    use_sched_samp = use_gt_p < 1.0
    batch_in, batch_out, meta = data

    # global (world-frame) GT is only needed when rolling out (scheduled sampling)
    prep_data = model.prepare_input(
        batch_in,
        device,
        data_out=batch_out,
        return_input_dict=True,
        return_global_dict=use_sched_samp,
    )
    if use_sched_samp:
        x_past, x_t, gt_dict, input_dict, global_gt_dict = prep_data
    else:
        x_past, x_t, gt_dict, input_dict = prep_data

    B, T, S_in, _ = x_past.size()
    S_out = x_t.size(2)

    if not use_sched_samp:
        # fully supervised phase
        # start by using gt at every step, so just form all steps from all sequences into one large batch
        # and get per-step predictions
        x_past_batched = x_past.reshape((B * T, S_in, -1))
        x_t_batched = x_t.reshape((B * T, S_out, -1))
        out_dict = model(x_past_batched, x_t_batched)
    else:
        # in scheduled sampling or fully autoregressive phase
        init_input_dict = dict()
        for k in input_dict.keys():
            init_input_dict[k] = input_dict[k][
                :, 0, :, :
            ]  # only need first step for init
        # this out_dict is the global state
        sched_samp_out = model.scheduled_sampling(
            x_past,
            x_t,
            init_input_dict,
            p=use_gt_p,
            gender=meta["gender"],
            betas=meta["betas"].to(device),
            need_global_out=(not model.detach_sched_samp),
        )
        if model.detach_sched_samp:
            out_dict = sched_samp_out
        else:
            out_dict, _ = sched_samp_out
        # gt must be global state for supervision in this case
        if not model.detach_sched_samp:
            print("USING global supervision")
            gt_dict = global_gt_dict

    # loss can be computed per output step in parallel
    # batch dicts accordingly (distribution entries are (mean, var) pairs,
    # everything else is a tensor; both are flattened in place)
    for k in out_dict.keys():
        if k == "posterior_distrib" or k == "prior_distrib":
            m, v = out_dict[k]
            m = m.reshape((B * T, -1))
            v = v.reshape((B * T, -1))
            out_dict[k] = (m, v)
        else:
            out_dict[k] = out_dict[k].reshape((B * T * S_out, -1))
    for k in gt_dict.keys():
        gt_dict[k] = gt_dict[k].reshape((B * T * S_out, -1))

    # replicate per-sequence gender/betas metadata to one entry per output step
    gender_in = np.broadcast_to(
        np.array(meta["gender"]).reshape((B, 1, 1, 1)), (B, T, S_out, 1)
    )
    gender_in = gender_in.reshape((B * T * S_out, 1))
    betas_in = meta["betas"].reshape((B, T, 1, -1)).expand((B, T, S_out, 16)).to(device)
    betas_in = betas_in.reshape((B * T * S_out, 16))
    loss, stats_dict = loss_func(
        out_dict, gt_dict, cur_epoch, gender=gender_in, betas=betas_in
    )

    return loss, stats_dict
133
+
134
+
135
+ class HumorModel(nn.Module):
136
    def __init__(
        self,
        in_rot_rep="aa",
        out_rot_rep="aa",
        latent_size=48,
        steps_in=1,
        conditional_prior=True,  # use a learned prior rather than standard normal
        output_delta=True,  # output change in state from decoder rather than next step directly
        posterior_arch="mlp",
        decoder_arch="mlp",
        prior_arch="mlp",
        model_data_config="smpl+joints+contacts",
        detach_sched_samp=True,  # if true, detaches outputs of previous step so gradients don't flow through many steps
        model_use_smpl_joint_inputs=False,  # if true, uses smpl joints rather than regressed joints to input at next step (during rollout and sched samp)
        model_smpl_batch_size=1,  # if using smpl joint inputs this should be batch_size of the smpl model (aka data input to rollout)
    ):
        """
        Build the HuMoR CVAE: a posterior encoder, a decoder, and (optionally)
        a learned conditional prior, sized from the chosen data configuration
        and rotation representations. Only single-step output is supported.
        """
        super(HumorModel, self).__init__()
        self.ignore_keys = []

        self.steps_in = steps_in
        self.steps_out = 1
        self.out_step_size = 1
        self.detach_sched_samp = detach_sched_samp
        self.output_delta = output_delta

        if self.steps_out > 1:
            raise NotImplementedError("Only supported single step output currently.")

        # validate rotation representations and architecture choices up front
        if out_rot_rep not in OUT_ROT_REPS:
            raise Exception(
                "Not a valid output rotation representation: %s" % (out_rot_rep)
            )
        if in_rot_rep not in IN_ROT_REPS:
            raise Exception(
                "Not a valid input rotation representation: %s" % (in_rot_rep)
            )
        self.out_rot_rep = out_rot_rep
        self.in_rot_rep = in_rot_rep

        if posterior_arch not in POSTERIOR_OPTIONS:
            raise Exception("Not a valid encoder architecture: %s" % (posterior_arch))
        if decoder_arch not in DECODER_OPTIONS:
            raise Exception("Not a valid decoder architecture: %s" % (decoder_arch))
        if conditional_prior and prior_arch not in PRIOR_OPTIONS:
            raise Exception("Not a valid prior architecture: %s" % (prior_arch))
        self.posterior_arch = posterior_arch
        self.decoder_arch = decoder_arch
        self.prior_arch = prior_arch

        # get the list of data names for this config
        self.data_names = data_name_list(model_data_config)
        self.aux_in_data_names = (
            self.aux_out_data_names
        ) = None  # auxiliary data will be returned as part of the input/output dictionary, but not the actual network input/output tensor
        self.pred_contacts = False
        if (
            model_data_config.find("contacts") >= 0
        ):  # network is outputting contact classification as well and need to supervise, but not given as input to net.
            self.data_names.remove("contacts")
            self.aux_out_data_names = ["contacts"]
            self.pred_contacts = True

        # joints/verts are stored root-relative, so a trans->joint offset is needed
        self.need_trans2joint = (
            "joints" in self.data_names or "verts" in self.data_names
        )
        self.model_data_config = model_data_config

        # per-stream input dimensions under the input rotation representation
        self.input_rot_dim = ROT_REP_SIZE[self.in_rot_rep]
        self.input_dim_list = [
            data_dim(dname, rot_rep_size=self.input_rot_dim)
            for dname in self.data_names
        ]
        self.input_data_dim = sum(self.input_dim_list)

        # per-stream output dimensions; the "delta" list always uses matrices
        # because residual rotations are composed as rotation matrices
        self.output_rot_dim = ROT_REP_SIZE[self.out_rot_rep]
        self.output_dim_list = [
            data_dim(dname, rot_rep_size=self.output_rot_dim)
            for dname in self.data_names
        ]
        self.delta_output_dim_list = [
            data_dim(dname, rot_rep_size=ROT_REP_SIZE["mat"])
            for dname in self.data_names
        ]

        if self.pred_contacts:
            # account for contact classification output
            self.output_dim_list.append(data_dim("contacts"))
            self.delta_output_dim_list.append(data_dim("contacts"))

        self.output_data_dim = sum(self.output_dim_list)

        self.latent_size = latent_size
        past_data_dim = self.steps_in * self.input_data_dim
        t_data_dim = self.steps_out * self.input_data_dim

        # posterior encoder (given past and future, predict latent transition distribution)
        print("Using posterior architecture: %s" % (self.posterior_arch))
        if self.posterior_arch == "mlp":
            layer_list = [
                past_data_dim + t_data_dim,
                1024,
                1024,
                1024,
                1024,
                self.latent_size * 2,
            ]
            self.encoder = MLP(
                layers=layer_list,  # mu and sigma output
                nonlinearity=nn.ReLU,
                use_gn=True,
            )

        # decoder (given past and latent transition, predict future) for the immediate next step
        print("Using decoder architecture: %s" % (self.decoder_arch))
        decoder_input_dim = past_data_dim + self.latent_size
        if self.decoder_arch == "mlp":
            layer_list = [decoder_input_dim, 1024, 1024, 512, self.output_data_dim]
            self.decoder = MLP(
                layers=layer_list,
                nonlinearity=nn.ReLU,
                use_gn=True,
                skip_input_idx=past_data_dim,  # skip connect the latent to every layer
            )

        # prior (if conditional, given past predict latent transition distribution)
        self.use_conditional_prior = conditional_prior
        if self.use_conditional_prior:
            print("Using prior architecture: %s" % (self.prior_arch))
            layer_list = [past_data_dim, 1024, 1024, 1024, 1024, self.latent_size * 2]
            self.prior_net = MLP(
                layers=layer_list,  # mu and sigma output
                nonlinearity=nn.ReLU,
                use_gn=True,
            )
        else:
            print("Using standard normal prior.")

        self.use_smpl_joint_inputs = model_use_smpl_joint_inputs
        self.smpl_batch_size = model_smpl_batch_size
        if self.use_smpl_joint_inputs:
            # need a body model to compute the joints after each step.
            print(
                "Using SMPL joints rather than regressed joints as input at each step for roll out and scheduled sampling..."
            )
            # one frozen SMPL-H body model per gender; selected at runtime by name
            male_bm_path = os.path.join(SMPLH_PATH, "male/model.npz")
            self.male_bm = BodyModel(
                bm_path=male_bm_path, num_betas=16, batch_size=self.smpl_batch_size
            )
            female_bm_path = os.path.join(SMPLH_PATH, "female/model.npz")
            self.female_bm = BodyModel(
                bm_path=female_bm_path, num_betas=16, batch_size=self.smpl_batch_size
            )
            neutral_bm_path = os.path.join(SMPLH_PATH, "neutral/model.npz")
            self.neutral_bm = BodyModel(
                bm_path=neutral_bm_path, num_betas=16, batch_size=self.smpl_batch_size
            )
            self.bm_dict = {
                "male": self.male_bm,
                "female": self.female_bm,
                "neutral": self.neutral_bm,
            }
            # body models are fixed; never trained with the CVAE
            for p in self.male_bm.parameters():
                p.requires_grad = False
            for p in self.female_bm.parameters():
                p.requires_grad = False
            for p in self.neutral_bm.parameters():
                p.requires_grad = False
            self.ignore_keys = ["male_bm", "female_bm", "neutral_bm"]
304
+
305
    def prepare_input(
        self,
        data_in,
        device,
        data_out=None,
        return_input_dict=False,
        return_global_dict=False,
    ):
        """
        Concatenates input and output data as expected by the model.

        Also creates a dictionary of GT outputs for use in computing the loss. And optionally
        a dictionary of inputs.

        Each stream in ``data_in``/``data_out`` is assumed to be a tensor of
        shape (B, T, ...) that reshapes to (B, T, steps_in/steps_out, -1)
        — TODO confirm against the dataset loader.

        Returns (depending on flags):
        - x_past only, or (x_past, input_dict) when data_out is None
        - (x_past, x_t, gt_dict[, input_dict][, global_gt_dict]) otherwise
        """

        #
        # input data
        #
        in_unnorm_data_list = []
        for k in self.data_names:
            cur_dat = data_in[k].to(device)
            B, T = cur_dat.size(0), cur_dat.size(1)
            cur_unnorm_dat = cur_dat.reshape((B, T, self.steps_in, -1))
            in_unnorm_data_list.append(cur_unnorm_dat)
        # concatenate all streams along the feature axis
        x_past = torch.cat(in_unnorm_data_list, axis=3)

        input_dict = None
        if return_input_dict:
            input_dict = {k: v for k, v in zip(self.data_names, in_unnorm_data_list)}

        # auxiliary inputs go into the dict only, not the network tensor
        if self.aux_in_data_names is not None:
            for k in self.aux_in_data_names:
                cur_dat = data_in[k].to(device)
                B, T = cur_dat.size(0), cur_dat.size(1)
                cur_unnorm_dat = cur_dat.reshape((B, T, self.steps_in, -1))
                input_dict[k] = cur_unnorm_dat

        #
        # output
        #
        if data_out is not None:
            out_unnorm_data_list = []
            for k in self.data_names:
                cur_dat = data_out[k].to(device)
                B, T = cur_dat.size(0), cur_dat.size(1)
                cur_unnorm_dat = cur_dat.reshape((B, T, self.steps_out, -1))
                out_unnorm_data_list.append(cur_unnorm_dat)
            x_t = torch.cat(out_unnorm_data_list, axis=3)
            gt_dict = {k: v for k, v in zip(self.data_names, out_unnorm_data_list)}

            # auxiliary outputs (e.g. contacts) are supervised but not net inputs
            if self.aux_out_data_names is not None:
                for k in self.aux_out_data_names:
                    cur_dat = data_out[k].to(device)
                    B, T = cur_dat.size(0), cur_dat.size(1)
                    cur_unnorm_dat = cur_dat.reshape((B, T, self.steps_out, -1))
                    gt_dict[k] = cur_unnorm_dat

            return_list = [x_past, x_t, gt_dict]
            if return_input_dict:
                return_list.append(input_dict)

            #
            # global
            #
            # world-frame GT, read from "global_"-prefixed keys in data_out
            if return_global_dict:
                global_gt_dict = dict()
                for k in self.data_names:
                    global_k = "global_" + k
                    cur_dat = data_out[global_k].to(device)
                    B, T = cur_dat.size(0), cur_dat.size(1)
                    # expand each to have steps_out since originally they are just B x T x ... x D
                    cur_dat = cur_dat.reshape((B, T, 1, -1)).expand_as(gt_dict[k])
                    global_gt_dict[k] = cur_dat

                if self.aux_out_data_names is not None:
                    for k in self.aux_out_data_names:
                        global_k = "global_" + k
                        cur_dat = data_out[global_k].to(device)
                        B, T = cur_dat.size(0), cur_dat.size(1)
                        # expand each to have steps_out since originally they are just B x T x ... x D
                        cur_dat = cur_dat.reshape((B, T, 1, -1)).expand_as(gt_dict[k])
                        global_gt_dict[k] = cur_dat

                return_list.append(global_gt_dict)

            return tuple(return_list)

        else:
            if return_input_dict:
                return x_past, input_dict
            else:
                return x_past
397
+
398
+ def split_output(self, decoder_out, convert_rots=True):
399
+ """
400
+ Given the output of the decoder, splits into each state component.
401
+ Also transform rotation representation to matrices.
402
+
403
+ Input:
404
+ - decoder_out (B x steps_out x D)
405
+
406
+ Returns:
407
+ - output dict
408
+ """
409
+ B = decoder_out.size(0)
410
+ decoder_out = decoder_out.reshape((B, self.steps_out, -1))
411
+
412
+ # collect outputs
413
+ name_list = self.data_names
414
+ if self.aux_out_data_names is not None:
415
+ name_list = name_list + self.aux_out_data_names
416
+ idx_list = (
417
+ self.delta_output_dim_list if self.output_delta else self.output_dim_list
418
+ )
419
+ out_dict = dict()
420
+ sidx = 0
421
+ for cur_name, cur_idx in zip(name_list, idx_list):
422
+ eidx = sidx + cur_idx
423
+ out_dict[cur_name] = decoder_out[:, :, sidx:eidx]
424
+ sidx = eidx
425
+
426
+ # transform rotations
427
+ if convert_rots and not self.output_delta: # output delta already gives rotmats
428
+ if "root_orient" in self.data_names:
429
+ out_dict["root_orient"] = convert_to_rotmat(
430
+ out_dict["root_orient"], rep=self.out_rot_rep
431
+ )
432
+ if "pose_body" in self.data_names:
433
+ out_dict["pose_body"] = convert_to_rotmat(
434
+ out_dict["pose_body"], rep=self.out_rot_rep
435
+ )
436
+
437
+ return out_dict
438
+
439
+ def forward(self, x_past, x_t):
440
+ """
441
+ single step full forward pass. This uses the posterior for sampling, not the prior.
442
+
443
+ Input:
444
+ - x_past (B x steps_in x D)
445
+ - x_t (B x steps_out x D)
446
+
447
+ Returns dict of:
448
+ - x_pred (B x steps_out x D)
449
+ - posterior_distrib (Normal(mu, sigma))
450
+ - prior_distrib (Normal(mu, sigma))
451
+ """
452
+
453
+ B, _, D = x_past.size()
454
+ past_in = x_past.reshape((B, -1))
455
+ t_in = x_t.reshape((B, -1))
456
+
457
+ x_pred_dict = self.single_step(past_in, t_in)
458
+
459
+ return x_pred_dict
460
+
461
+ def single_step(self, past_in, t_in):
462
+ """
463
+ single step that computes both prior and posterior for training. Samples from posterior
464
+ """
465
+ B = past_in.size(0)
466
+ # use past and future to encode latent transition
467
+ qm, qv = self.posterior(past_in, t_in)
468
+
469
+ # prior
470
+ pm, pv = None, None
471
+ if self.use_conditional_prior:
472
+ # predict prior based on past
473
+ pm, pv = self.prior(past_in)
474
+ else:
475
+ # use standard normal
476
+ pm, pv = torch.zeros_like(qm), torch.ones_like(qv)
477
+
478
+ # sample from posterior using reparam trick
479
+ z = self.rsample(qm, qv)
480
+
481
+ # decode to get next step
482
+ decoder_out = self.decode(z, past_in)
483
+ decoder_out = decoder_out.reshape(
484
+ (B, self.steps_out, -1)
485
+ ) # B x steps_out x D_out
486
+
487
+ # split output predictions and transform out rotations to matrices
488
+ x_pred_dict = self.split_output(decoder_out)
489
+
490
+ x_pred_dict["posterior_distrib"] = (qm, qv)
491
+ x_pred_dict["prior_distrib"] = (pm, pv)
492
+
493
+ return x_pred_dict
494
+
495
+ def prior(self, past_in):
496
+ """
497
+ Encodes the posterior distribution using the past and future states.
498
+
499
+ Input:
500
+ - past_in (B x steps_in*D)
501
+ """
502
+ prior_out = self.prior_net(past_in)
503
+ mean = prior_out[:, : self.latent_size]
504
+ logvar = prior_out[:, self.latent_size :]
505
+ var = torch.exp(logvar)
506
+ return mean, var
507
+
508
+ def posterior(self, past_in, t_in):
509
+ """
510
+ Encodes the posterior distribution using the past and future states.
511
+
512
+ Input:
513
+ - past_in (B x steps_in*D)
514
+ - t_in (B x steps_out*D)
515
+ """
516
+ encoder_in = torch.cat([past_in, t_in], axis=1)
517
+
518
+ encoder_out = self.encoder(encoder_in)
519
+ mean = encoder_out[:, : self.latent_size]
520
+ logvar = encoder_out[:, self.latent_size :]
521
+ var = torch.exp(logvar)
522
+
523
+ return mean, var
524
+
525
+ def rsample(self, mu, var):
526
+ """
527
+ Return gaussian sample of (mu, var) using reparameterization trick.
528
+ """
529
+ eps = torch.randn_like(mu)
530
+ z = mu + eps * torch.sqrt(var)
531
+ return z
532
+
533
    def decode(self, z, past_in):
        """
        Decodes prediction from the latent transition and past states

        Input:
        - z (B x latent_size)
        - past_in (B x steps_in*D)

        Returns:
        - decoder_out (B x steps_out*D)
        """
        B = z.size(0)
        decoder_in = torch.cat([past_in, z], axis=1)
        decoder_out = self.decoder(decoder_in).reshape((B, 1, -1))

        if self.output_delta:
            # network output is the residual, add to the input to get final output
            step_in = past_in.reshape((B, self.steps_in, -1))[
                :, -1:, :
            ]  # most recent input step

            final_out_list = []
            # walk the input and output feature axes in lockstep; widths can
            # differ per stream because rotation representations may differ
            in_sidx = out_sidx = 0
            decode_out_dim_list = self.output_dim_list
            if self.pred_contacts:
                decode_out_dim_list = decode_out_dim_list[:-1]  # do contacts separately
            for in_dim_idx, out_dim_idx, data_name in zip(
                self.input_dim_list, decode_out_dim_list, self.data_names
            ):
                in_eidx = in_sidx + in_dim_idx
                out_eidx = out_sidx + out_dim_idx

                # add residual to input (and transform as necessary for rotations)
                in_val = step_in[:, :, in_sidx:in_eidx]
                out_val = decoder_out[:, :, out_sidx:out_eidx]
                if data_name in ["root_orient", "pose_body"]:
                    # rotations compose multiplicatively: convert both to
                    # matrices and left-multiply by the predicted residual
                    if self.in_rot_rep != "mat":
                        in_val = convert_to_rotmat(in_val, rep=self.in_rot_rep)
                    out_val = convert_to_rotmat(out_val, rep=self.out_rot_rep)

                    in_val = in_val.reshape((B, 1, -1, 3, 3))
                    out_val = out_val.reshape((B, self.steps_out, -1, 3, 3))

                    rot_in = torch.matmul(out_val, in_val).reshape(
                        (B, self.steps_out, -1)
                    )  # rotate by predicted residual
                    final_out_list.append(rot_in)
                else:
                    # non-rotation streams are plain additive residuals
                    final_out_list.append(out_val + in_val)

                in_sidx = in_eidx
                out_sidx = out_eidx
            if self.pred_contacts:
                # contacts are a direct classification output, not a residual
                final_out_list.append(decoder_out[:, :, out_sidx:])

            decoder_out = torch.cat(final_out_list, dim=2)

        decoder_out = decoder_out.reshape((B, -1))

        return decoder_out
593
+
594
+ def scheduled_sampling(
595
+ self,
596
+ x_past,
597
+ x_t,
598
+ init_input_dict,
599
+ p=0.5,
600
+ gender=None,
601
+ betas=None,
602
+ need_global_out=True,
603
+ ):
604
+ """
605
+ Given all inputs and ground truth outputs for all steps, roll out model predictions
606
+ where at each step use the GT input with prob p, otherwise use own previous output.
607
+
608
+ Input:
609
+ - x_past (B x T x steps_in x D)
610
+ - x_t (B x T x steps_out x D)
611
+ - init_input_dict : dictionary of each initial state (B x steps_in x D), rotations should be matrices
612
+ - p : probability of using the GT input at each step of the sequence
613
+ - gender/betas only required if self.use_smpl_joint_inputs is true (used to decide the SMPL body model)
614
+ """
615
+ B, T, S, D = x_past.size()
616
+ S_out = x_t.size(2)
617
+ J = len(SMPL_JOINTS)
618
+ cur_input_dict = init_input_dict # this is the predicted input dict
619
+
620
+ # initial input must be from GT since we don't have any predictions yet
621
+ past_in = x_past[:, 0, :, :].reshape((B, -1))
622
+ t_in = x_t[:, 0, :, :].reshape((B, -1))
623
+
624
+ global_world2local_rot = (
625
+ torch.eye(3).reshape((1, 1, 3, 3)).expand((B, 1, 3, 3)).to(x_past)
626
+ )
627
+ global_world2local_trans = torch.zeros((B, 1, 3)).to(x_past)
628
+ trans2joint = torch.zeros((B, 1, 1, 3)).to(x_past)
629
+ if self.need_trans2joint:
630
+ trans2joint = -torch.cat(
631
+ [cur_input_dict["joints"][:, -1, :2], torch.zeros((B, 1)).to(x_past)],
632
+ axis=1,
633
+ ).reshape(
634
+ (B, 1, 1, 3)
635
+ ) # same for whole sequence
636
+ pred_local_seq = []
637
+ pred_global_seq = []
638
+ for t in range(T):
639
+ # sample next step from model
640
+ x_pred_dict = self.single_step(past_in, t_in)
641
+
642
+ # save output
643
+ pred_local_seq.append(x_pred_dict)
644
+
645
+ # output is the actual regressed joints, but input to next step can use smpl joints
646
+ x_pred_smpl_joints = None
647
+ if self.use_smpl_joint_inputs and gender is not None and betas is not None:
648
+ # this assumes the model is actually outputting everything we need to run SMPL
649
+ # also assumes single output step
650
+ smpl_trans = x_pred_dict["trans"][:, 0:1].reshape(
651
+ (B, 3)
652
+ ) # only want immediate next frame
653
+ smpl_root_orient = rotation_matrix_to_angle_axis(
654
+ x_pred_dict["root_orient"][:, 0:1].reshape((B, 3, 3))
655
+ ).reshape((B, 3))
656
+ smpl_betas = betas[:, 0, :]
657
+ smpl_pose_body = rotation_matrix_to_angle_axis(
658
+ x_pred_dict["pose_body"][:, 0:1].reshape((B * (J - 1), 3, 3))
659
+ ).reshape((B, (J - 1) * 3))
660
+
661
+ smpl_vals = [smpl_trans, smpl_root_orient, smpl_betas, smpl_pose_body]
662
+ # batch may be a mix of genders, so need to carefully use the corresponding SMPL body model
663
+ gender_names = ["male", "female", "neutral"]
664
+ pred_joints = []
665
+ prev_nbidx = 0
666
+ cat_idx_map = np.ones((B), dtype=np.int) * -1
667
+ for gender_name in gender_names:
668
+ gender_idx = np.array(gender) == gender_name
669
+ nbidx = np.sum(gender_idx)
670
+
671
+ cat_idx_map[gender_idx] = np.arange(
672
+ prev_nbidx, prev_nbidx + nbidx, dtype=np.int
673
+ )
674
+ prev_nbidx += nbidx
675
+
676
+ gender_smpl_vals = [val[gender_idx] for val in smpl_vals]
677
+
678
+ # need to pad extra frames with zeros in case not as long as expected
679
+ pad_size = self.smpl_batch_size - nbidx
680
+ if pad_size == B:
681
+ # skip if no frames for this gender
682
+ continue
683
+ pad_list = gender_smpl_vals
684
+ if pad_size < 0:
685
+ raise Exception(
686
+ "SMPL model batch size not large enough to accomodate!"
687
+ )
688
+ elif pad_size > 0:
689
+ pad_list = self.zero_pad_tensors(pad_list, pad_size)
690
+
691
+ # reconstruct SMPL
692
+ cur_pred_trans, cur_pred_orient, cur_betas, cur_pred_pose = pad_list
693
+ bm = self.bm_dict[gender_name]
694
+ pred_body = bm(
695
+ pose_body=cur_pred_pose,
696
+ betas=cur_betas,
697
+ root_orient=cur_pred_orient,
698
+ trans=cur_pred_trans,
699
+ )
700
+ if pad_size > 0:
701
+ pred_joints.append(pred_body.Jtr[:-pad_size])
702
+ else:
703
+ pred_joints.append(pred_body.Jtr)
704
+
705
+ # cat all genders and reorder to original batch ordering
706
+ x_pred_smpl_joints = torch.cat(pred_joints, axis=0)[
707
+ :, : len(SMPL_JOINTS), :
708
+ ].reshape((B, 1, -1))
709
+ x_pred_smpl_joints = x_pred_smpl_joints[cat_idx_map]
710
+
711
+ # prepare predicted input to next step in case needed
712
+ # update input dict with new frame
713
+ del_keys = []
714
+ for k in cur_input_dict.keys():
715
+ if k in x_pred_dict:
716
+ # drop oldest frame and add new prediction
717
+ keep_frames = cur_input_dict[k][:, 1:, :]
718
+ # print(keep_frames.size())
719
+ if (
720
+ k == "joints"
721
+ and self.use_smpl_joint_inputs
722
+ and x_pred_smpl_joints is not None
723
+ ):
724
+ # print('Using SMPL joints rather than regressed joints...')
725
+ if self.detach_sched_samp:
726
+ cur_input_dict[k] = torch.cat(
727
+ [keep_frames, x_pred_smpl_joints.detach()], axis=1
728
+ )
729
+ else:
730
+ cur_input_dict[k] = torch.cat(
731
+ [keep_frames, x_pred_smpl_joints], axis=1
732
+ )
733
+ else:
734
+ if self.detach_sched_samp:
735
+ cur_input_dict[k] = torch.cat(
736
+ [keep_frames, x_pred_dict[k][:, 0:1, :].detach()],
737
+ axis=1,
738
+ )
739
+ else:
740
+ cur_input_dict[k] = torch.cat(
741
+ [keep_frames, x_pred_dict[k][:, 0:1, :]], axis=1
742
+ )
743
+ # print(cur_input_dict[k].size())
744
+ else:
745
+ del_keys.append(k)
746
+ for k in del_keys:
747
+ del cur_input_dict[k] # don't need it anymore
748
+
749
+ # get world2aligned rot and translation
750
+ if self.detach_sched_samp:
751
+ root_orient_mat = (
752
+ x_pred_dict["root_orient"][:, 0, :].reshape((B, 3, 3)).detach()
753
+ )
754
+ world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
755
+ world2aligned_trans = torch.cat(
756
+ [
757
+ -x_pred_dict["trans"][:, 0, :2].detach(),
758
+ torch.zeros((B, 1)).to(x_past),
759
+ ],
760
+ axis=1,
761
+ )
762
+ else:
763
+ root_orient_mat = x_pred_dict["root_orient"][:, 0, :].reshape((B, 3, 3))
764
+ world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
765
+ world2aligned_trans = torch.cat(
766
+ [-x_pred_dict["trans"][:, 0, :2], torch.zeros((B, 1)).to(x_past)],
767
+ axis=1,
768
+ )
769
+
770
+ #
771
+ # transform inputs to this local frame for next step
772
+ #
773
+ cur_input_dict = self.apply_world2local_trans(
774
+ world2aligned_trans,
775
+ world2aligned_rot,
776
+ trans2joint,
777
+ cur_input_dict,
778
+ cur_input_dict,
779
+ invert=False,
780
+ )
781
+
782
+ # convert rots to correct input format
783
+ if self.in_rot_rep == "aa":
784
+ if "root_orient" in self.data_names:
785
+ cur_input_dict["root_orient"] = rotation_matrix_to_angle_axis(
786
+ cur_input_dict["root_orient"].reshape((B * S, 3, 3))
787
+ ).reshape((B, S, 3))
788
+ if "pose_body" in self.data_names:
789
+ cur_input_dict["pose_body"] = rotation_matrix_to_angle_axis(
790
+ cur_input_dict["pose_body"].reshape((B * S * (J - 1), 3, 3))
791
+ ).reshape((B, S, (J - 1) * 3))
792
+ elif self.in_rot_rep == "6d":
793
+ if "root_orient" in self.data_names:
794
+ cur_input_dict["root_orient"] = cur_input_dict["root_orient"][
795
+ :, :, :6
796
+ ]
797
+ if "pose_body" in self.data_names:
798
+ cur_input_dict["pose_body"] = (
799
+ cur_input_dict["pose_body"]
800
+ .reshape((B, S, J - 1, 9))[:, :, :, :6]
801
+ .reshape((B, S, (J - 1) * 6))
802
+ )
803
+
804
+ if need_global_out:
805
+ #
806
+ # compute current world output and update world2local transform
807
+ #
808
+ cur_world_dict = dict()
809
+ cur_world_dict = self.apply_world2local_trans(
810
+ global_world2local_trans,
811
+ global_world2local_rot,
812
+ trans2joint,
813
+ x_pred_dict,
814
+ cur_world_dict,
815
+ invert=True,
816
+ )
817
+
818
+ if self.detach_sched_samp:
819
+ global_world2local_trans = torch.cat(
820
+ [
821
+ -cur_world_dict["trans"][:, 0:1, :2].detach(),
822
+ torch.zeros((B, 1, 1)).to(x_past),
823
+ ],
824
+ axis=2,
825
+ )
826
+ else:
827
+ global_world2local_trans = torch.cat(
828
+ [
829
+ -cur_world_dict["trans"][:, 0:1, :2],
830
+ torch.zeros((B, 1, 1)).to(x_past),
831
+ ],
832
+ axis=2,
833
+ )
834
+
835
+ global_world2local_rot = torch.matmul(
836
+ global_world2local_rot, world2aligned_rot.reshape((B, 1, 3, 3))
837
+ )
838
+
839
+ pred_global_seq.append(cur_world_dict)
840
+
841
+ if t + 1 < T:
842
+ # choose whether next step will use GT or predicted inputs and prepare them
843
+ if np.random.random_sample() < p:
844
+ # use GT
845
+ past_in = x_past[:, t + 1, :, :].reshape((B, -1))
846
+ else:
847
+ # cat all inputs together to form past_in
848
+ in_data_list = []
849
+ for k in self.data_names:
850
+ in_data_list.append(cur_input_dict[k])
851
+ past_in = torch.cat(in_data_list, axis=2)
852
+ past_in = past_in.reshape((B, -1))
853
+
854
+ # GT output is the same no matter what
855
+ t_in = x_t[:, t + 1, :, :].reshape((B, -1))
856
+
857
+ if need_global_out:
858
+ # aggregate pred_seq
859
+ pred_global_seq_out = dict()
860
+ for k in pred_global_seq[0].keys():
861
+ if k == "posterior_distrib" or k == "prior_distrib":
862
+ m = torch.stack(
863
+ [pred_global_seq[i][k][0] for i in range(len(pred_global_seq))],
864
+ axis=1,
865
+ )
866
+ v = torch.stack(
867
+ [pred_global_seq[i][k][1] for i in range(len(pred_global_seq))],
868
+ axis=1,
869
+ )
870
+ pred_global_seq_out[k] = (m, v)
871
+ else:
872
+ pred_global_seq_out[k] = torch.stack(
873
+ [pred_global_seq[i][k] for i in range(len(pred_global_seq))],
874
+ axis=1,
875
+ )
876
+
877
+ # aggregate pred_seq
878
+ pred_local_seq_out = dict()
879
+ for k in pred_local_seq[0].keys():
880
+ # print(k)
881
+ if k == "posterior_distrib" or k == "prior_distrib":
882
+ m = torch.stack(
883
+ [pred_local_seq[i][k][0] for i in range(len(pred_local_seq))],
884
+ axis=1,
885
+ )
886
+ v = torch.stack(
887
+ [pred_local_seq[i][k][1] for i in range(len(pred_local_seq))],
888
+ axis=1,
889
+ )
890
+ pred_local_seq_out[k] = (m, v)
891
+ else:
892
+ pred_local_seq_out[k] = torch.stack(
893
+ [pred_local_seq[i][k] for i in range(len(pred_local_seq))], axis=1
894
+ )
895
+
896
+ if need_global_out:
897
+ return pred_global_seq_out, pred_local_seq_out
898
+ else:
899
+ return pred_local_seq_out
900
+
901
+ def apply_world2local_trans(
902
+ self,
903
+ world2local_trans,
904
+ world2local_rot,
905
+ trans2joint,
906
+ input_dict,
907
+ output_dict,
908
+ invert=False,
909
+ ):
910
+ """
911
+ Applies the given world2local transformation to the data in input_dict and stores the result in output_dict.
912
+
913
+ If invert is true, applies local2world.
914
+
915
+ - world2local_trans : B x 3 or B x 1 x 3
916
+ - world2local_rot : B x 3 x 3 or B x 1 x 3 x 3
917
+ - trans2joint : B x 1 x 1 x 3
918
+ """
919
+ B = world2local_trans.size(0)
920
+ world2local_rot = world2local_rot.reshape((B, 1, 3, 3))
921
+ world2local_trans = world2local_trans.reshape((B, 1, 3))
922
+ trans2joint = trans2joint.reshape((B, 1, 1, 3))
923
+ if invert:
924
+ local2world_rot = world2local_rot.transpose(3, 2)
925
+ for k, v in input_dict.items():
926
+ # apply differently depending on which data value it is
927
+ if k not in WORLD2ALIGN_NAME_CACHE:
928
+ # frame of reference is irrelevant, just copy to output
929
+ output_dict[k] = input_dict[k]
930
+ continue
931
+
932
+ S = input_dict[k].size(1)
933
+ if k in ["root_orient"]:
934
+ # rot: B x S x 3 x 3 sized rotation matrix input
935
+ input_mat = input_dict[k].reshape(
936
+ (B, S, 3, 3)
937
+ ) # make sure not B x S x 9
938
+ if invert:
939
+ output_dict[k] = torch.matmul(local2world_rot, input_mat).reshape(
940
+ (B, S, 9)
941
+ )
942
+ else:
943
+ output_dict[k] = torch.matmul(world2local_rot, input_mat).reshape(
944
+ (B, S, 9)
945
+ )
946
+ elif k in ["trans"]:
947
+ # trans + rot : B x S x 3
948
+ input_trans = input_dict[k]
949
+ if invert:
950
+ output_trans = torch.matmul(
951
+ local2world_rot, input_trans.reshape((B, S, 3, 1))
952
+ )[:, :, :, 0]
953
+ output_trans = output_trans - world2local_trans
954
+ output_dict[k] = output_trans
955
+ else:
956
+ input_trans = input_trans + world2local_trans
957
+ output_dict[k] = torch.matmul(
958
+ world2local_rot, input_trans.reshape((B, S, 3, 1))
959
+ )[:, :, :, 0]
960
+ elif k in ["joints", "verts"]:
961
+ # trans + joint + rot : B x S x J x 3
962
+ J = input_dict[k].size(2) // 3
963
+ input_pts = input_dict[k].reshape((B, S, J, 3))
964
+ if invert:
965
+ input_pts = input_pts + trans2joint
966
+ output_pts = torch.matmul(
967
+ local2world_rot.reshape((B, 1, 1, 3, 3)),
968
+ input_pts.reshape((B, S, J, 3, 1)),
969
+ )[:, :, :, :, 0]
970
+ output_pts = (
971
+ output_pts
972
+ - trans2joint
973
+ - world2local_trans.reshape((B, 1, 1, 3))
974
+ )
975
+ output_dict[k] = output_pts.reshape((B, S, J * 3))
976
+ else:
977
+ input_pts = (
978
+ input_pts
979
+ + world2local_trans.reshape((B, 1, 1, 3))
980
+ + trans2joint
981
+ )
982
+ output_pts = torch.matmul(
983
+ world2local_rot.reshape((B, 1, 1, 3, 3)),
984
+ input_pts.reshape((B, S, J, 3, 1)),
985
+ )[:, :, :, :, 0]
986
+ output_pts = output_pts - trans2joint
987
+ output_dict[k] = output_pts.reshape((B, S, J * 3))
988
+ elif k in ["joints_vel", "verts_vel"]:
989
+ # rot : B x S x J x 3
990
+ J = input_dict[k].size(2) // 3
991
+ input_pts = input_dict[k].reshape((B, S, J, 3, 1))
992
+ if invert:
993
+ outuput_pts = torch.matmul(
994
+ local2world_rot.reshape((B, 1, 1, 3, 3)), input_pts
995
+ )[:, :, :, :, 0]
996
+ output_dict[k] = outuput_pts.reshape((B, S, J * 3))
997
+ else:
998
+ output_pts = torch.matmul(
999
+ world2local_rot.reshape((B, 1, 1, 3, 3)), input_pts
1000
+ )[:, :, :, :, 0]
1001
+ output_dict[k] = output_pts.reshape((B, S, J * 3))
1002
+ elif k in ["trans_vel", "root_orient_vel"]:
1003
+ # rot : B x S x 3
1004
+ input_pts = input_dict[k].reshape((B, S, 3, 1))
1005
+ if invert:
1006
+ output_dict[k] = torch.matmul(local2world_rot, input_pts)[
1007
+ :, :, :, 0
1008
+ ]
1009
+ else:
1010
+ output_dict[k] = torch.matmul(world2local_rot, input_pts)[
1011
+ :, :, :, 0
1012
+ ]
1013
+ else:
1014
+ print(
1015
+ "Received an unexpected key when transforming world2local: %s!"
1016
+ % (k)
1017
+ )
1018
+ exit()
1019
+
1020
+ return output_dict
1021
+
1022
+ def zero_pad_tensors(self, pad_list, pad_size):
1023
+ """
1024
+ Assumes tensors in pad_list are B x D
1025
+ """
1026
+ new_pad_list = []
1027
+ for pad_idx, pad_tensor in enumerate(pad_list):
1028
+ padding = torch.zeros((pad_size, pad_tensor.size(1))).to(pad_tensor)
1029
+ new_pad_list.append(torch.cat([pad_tensor, padding], dim=0))
1030
+ return new_pad_list
1031
+
1032
+ def roll_out(
1033
+ self,
1034
+ x_past,
1035
+ init_input_dict,
1036
+ num_steps,
1037
+ use_mean=False,
1038
+ z_seq=None,
1039
+ return_prior=False,
1040
+ gender=None,
1041
+ betas=None,
1042
+ return_z=False,
1043
+ canonicalize_input=False,
1044
+ uncanonicalize_output=False,
1045
+ ):
1046
+ """
1047
+ Given input for first step, roll out using own output the entire time by sampling from the prior.
1048
+ Returns the global trajectory.
1049
+
1050
+ Input:
1051
+ - x_past (B x steps_in x D_in)
1052
+ - initial_input_dict : dictionary of each initial state (B x steps_in x D), rotations should be matrices
1053
+ (assumes initial state is already in its local coordinate system (translation at [0,0,z] and aligned))
1054
+ - num_steps : the number of timesteps to roll out
1055
+ - use_mean : if True, uses the mean of latent distribution instead of sampling
1056
+ - z_seq : (B x steps_out x D) if given, uses as the latent input to decoder at each step rather than sampling
1057
+ - return_prior : if True, also returns the output of the conditional prior at each step
1058
+ -gender : list of e.g. ['male', 'female', etc..] of length B
1059
+ -betas : B x steps_in x D
1060
+ -return_z : returns the sampled z sequence in addition to the output
1061
+ - canonicalize_input : if true, the input initial state is assumed to not be in the local aligned coordinate system. It will be transformed before using.
1062
+ - uncanonicalize_output : if true and canonicalize_input=True, will transform output back into the input frame rather than return in canonical frame.
1063
+ Returns:
1064
+ - x_pred - dict of (B x num_steps x D_out) for each value. Rotations are all matrices.
1065
+ """
1066
+ J = len(SMPL_JOINTS)
1067
+ cur_input_dict = init_input_dict
1068
+
1069
+ # need to transform init input to local frame
1070
+ world2aligned_rot = world2aligned_trans = None
1071
+ if canonicalize_input:
1072
+ B, _, _ = cur_input_dict[list(cur_input_dict.keys())[0]].size()
1073
+ # must transform initial input into the local frame
1074
+ # get world2aligned rot and translation
1075
+ root_orient_mat = cur_input_dict["root_orient"]
1076
+ pose_body_mat = cur_input_dict["pose_body"]
1077
+ if "root_orient" in self.data_names and self.in_rot_rep != "mat":
1078
+ root_orient_mat = convert_to_rotmat(
1079
+ root_orient_mat, rep=self.in_rot_rep
1080
+ )
1081
+ if "pose_body" in self.data_names and self.in_rot_rep != "mat":
1082
+ pose_body_mat = convert_to_rotmat(pose_body_mat, rep=self.in_rot_rep)
1083
+
1084
+ root_orient_mat = root_orient_mat[:, -1].reshape((B, 3, 3))
1085
+ world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
1086
+ world2aligned_trans = torch.cat(
1087
+ [
1088
+ -cur_input_dict["trans"][:, -1, :2],
1089
+ torch.zeros((B, 1)).to(root_orient_mat),
1090
+ ],
1091
+ axis=1,
1092
+ )
1093
+
1094
+ # compute trans2joint
1095
+ if self.need_trans2joint:
1096
+ trans2joint = -(
1097
+ cur_input_dict["joints"][:, -1, :2] + world2aligned_trans[:, :2]
1098
+ )
1099
+ trans2joint = torch.cat(
1100
+ [trans2joint, torch.zeros((B, 1)).to(trans2joint)], axis=1
1101
+ ).reshape((B, 1, 1, 3))
1102
+
1103
+ # transform to local frame
1104
+ cur_input_dict = self.apply_world2local_trans(
1105
+ world2aligned_trans,
1106
+ world2aligned_rot,
1107
+ trans2joint,
1108
+ cur_input_dict,
1109
+ cur_input_dict,
1110
+ invert=False,
1111
+ )
1112
+
1113
+ # check to make sure we have enough input steps, if not, pad
1114
+ pad_x_past = x_past is not None and x_past.size(1) < self.steps_in
1115
+ pad_in_dict = (
1116
+ cur_input_dict[list(cur_input_dict.keys())[0]].size(1) < self.steps_in
1117
+ )
1118
+ if pad_x_past:
1119
+ num_pad_steps = self.steps_in - x_past.size(1)
1120
+ cur_padding = torch.zeros(
1121
+ (x_past.size(0), num_pad_steps, x_past.size(2))
1122
+ ).to(
1123
+ x_past
1124
+ ) # assuming all data is B x T x D
1125
+ x_past = torch.cat([cur_padding, x_past], axis=1)
1126
+ if pad_in_dict:
1127
+ for k in cur_input_dict.keys():
1128
+ cur_in_dat = cur_input_dict[k]
1129
+ num_pad_steps = self.steps_in - cur_in_dat.size(1)
1130
+ cur_padding = torch.zeros(
1131
+ (cur_in_dat.size(0), num_pad_steps, cur_in_dat.size(2))
1132
+ ).to(
1133
+ cur_in_dat
1134
+ ) # assuming all data is B x T x D
1135
+ padded_in_dat = torch.cat([cur_padding, cur_in_dat], axis=1)
1136
+ cur_input_dict[k] = padded_in_dat
1137
+
1138
+ if x_past is None or canonicalize_input:
1139
+ x_past = [cur_input_dict[k] for k in self.data_names]
1140
+ x_past = torch.cat(x_past, axis=2)
1141
+ B, S, D = x_past.size()
1142
+ past_in = x_past.reshape((B, -1))
1143
+
1144
+ global_world2local_rot = (
1145
+ torch.eye(3).reshape((1, 1, 3, 3)).expand((B, 1, 3, 3)).to(x_past)
1146
+ )
1147
+ global_world2local_trans = torch.zeros((B, 1, 3)).to(x_past)
1148
+ if canonicalize_input and uncanonicalize_output:
1149
+ global_world2local_rot = world2aligned_rot.unsqueeze(1)
1150
+ global_world2local_trans = world2aligned_trans.unsqueeze(1)
1151
+ trans2joint = torch.zeros((B, 1, 1, 3)).to(x_past)
1152
+ if self.need_trans2joint:
1153
+ trans2joint = -torch.cat(
1154
+ [cur_input_dict["joints"][:, -1, :2], torch.zeros((B, 1)).to(x_past)],
1155
+ axis=1,
1156
+ ).reshape(
1157
+ (B, 1, 1, 3)
1158
+ ) # same for whole sequence
1159
+ pred_local_seq = []
1160
+ pred_global_seq = []
1161
+ prior_seq = []
1162
+ z_out_seq = []
1163
+ for t in range(num_steps):
1164
+ x_pred_dict = None
1165
+ # sample next step
1166
+ z_in = None
1167
+ if z_seq is not None:
1168
+ z_in = z_seq[:, t]
1169
+ sample_out = self.sample_step(
1170
+ past_in,
1171
+ use_mean=use_mean,
1172
+ z=z_in,
1173
+ return_prior=return_prior,
1174
+ return_z=return_z,
1175
+ )
1176
+ if return_prior:
1177
+ prior_out = sample_out["prior"]
1178
+ prior_seq.append(prior_out)
1179
+ if return_z:
1180
+ z_out = sample_out["z"]
1181
+ z_out_seq.append(z_out)
1182
+ decoder_out = sample_out["decoder_out"]
1183
+
1184
+ # split output predictions and transform out rotations to matrices
1185
+ x_pred_dict = self.split_output(decoder_out, convert_rots=True)
1186
+ if self.steps_out > 1:
1187
+ for k in x_pred_dict.keys():
1188
+ # only want immediate next frame prediction
1189
+ x_pred_dict[k] = x_pred_dict[k][:, 0:1, :]
1190
+
1191
+ pred_local_seq.append(x_pred_dict)
1192
+
1193
+ # output is the actual regressed joints, but input to next step can use smpl joints
1194
+ x_pred_smpl_joints = None
1195
+ if self.use_smpl_joint_inputs and gender is not None and betas is not None:
1196
+ # this assumes the model is actually outputting everything we need to run SMPL
1197
+ # also assumes single output step
1198
+ smpl_trans = x_pred_dict["trans"].reshape((B, 3))
1199
+ smpl_root_orient = rotation_matrix_to_angle_axis(
1200
+ x_pred_dict["root_orient"].reshape((B, 3, 3))
1201
+ ).reshape((B, 3))
1202
+ smpl_betas = betas[:, 0, :]
1203
+ smpl_pose_body = rotation_matrix_to_angle_axis(
1204
+ x_pred_dict["pose_body"].reshape((B * (J - 1), 3, 3))
1205
+ ).reshape((B, (J - 1) * 3))
1206
+
1207
+ smpl_vals = [smpl_trans, smpl_root_orient, smpl_betas, smpl_pose_body]
1208
+ # each batch index may be a different gender
1209
+ gender_names = ["male", "female", "neutral"]
1210
+ pred_joints = []
1211
+ prev_nbidx = 0
1212
+ cat_idx_map = np.ones((B), dtype=np.int) * -1
1213
+ for gender_name in gender_names:
1214
+ gender_idx = np.array(gender) == gender_name
1215
+ nbidx = np.sum(gender_idx)
1216
+ cat_idx_map[gender_idx] = np.arange(
1217
+ prev_nbidx, prev_nbidx + nbidx, dtype=np.int
1218
+ )
1219
+ prev_nbidx += nbidx
1220
+
1221
+ gender_smpl_vals = [val[gender_idx] for val in smpl_vals]
1222
+
1223
+ # need to pad extra frames with zeros in case not as long as expected
1224
+ pad_size = self.smpl_batch_size - nbidx
1225
+ if pad_size == B:
1226
+ # skip if no frames for this gender
1227
+ continue
1228
+ pad_list = gender_smpl_vals
1229
+ if pad_size < 0:
1230
+ raise Exception(
1231
+ "SMPL model batch size not large enough to accomodate!"
1232
+ )
1233
+ elif pad_size > 0:
1234
+ pad_list = self.zero_pad_tensors(pad_list, pad_size)
1235
+
1236
+ # reconstruct SMPL
1237
+ cur_pred_trans, cur_pred_orient, cur_betas, cur_pred_pose = pad_list
1238
+ bm = self.bm_dict[gender_name]
1239
+ pred_body = bm(
1240
+ pose_body=cur_pred_pose,
1241
+ betas=cur_betas,
1242
+ root_orient=cur_pred_orient,
1243
+ trans=cur_pred_trans,
1244
+ )
1245
+ if pad_size > 0:
1246
+ pred_joints.append(pred_body.Jtr[:-pad_size])
1247
+ else:
1248
+ pred_joints.append(pred_body.Jtr)
1249
+
1250
+ # cat all genders and reorder to original batch ordering
1251
+ x_pred_smpl_joints = torch.cat(pred_joints, axis=0)[
1252
+ :, : len(SMPL_JOINTS), :
1253
+ ].reshape((B, 1, -1))
1254
+ x_pred_smpl_joints = x_pred_smpl_joints[cat_idx_map]
1255
+
1256
+ # prepare input to next step
1257
+ # update input dict with new frame
1258
+ del_keys = []
1259
+ for k in cur_input_dict.keys():
1260
+ if k in x_pred_dict:
1261
+ # drop oldest frame and add new prediction
1262
+ keep_frames = cur_input_dict[k][:, 1:, :]
1263
+ # print(keep_frames.size())
1264
+
1265
+ if (
1266
+ k == "joints"
1267
+ and self.use_smpl_joint_inputs
1268
+ and x_pred_smpl_joints is not None
1269
+ ):
1270
+ cur_input_dict[k] = torch.cat(
1271
+ [keep_frames, x_pred_smpl_joints], axis=1
1272
+ )
1273
+ else:
1274
+ cur_input_dict[k] = torch.cat(
1275
+ [keep_frames, x_pred_dict[k]], axis=1
1276
+ )
1277
+ else:
1278
+ del_keys.append(k)
1279
+ for k in del_keys:
1280
+ del cur_input_dict[k]
1281
+
1282
+ # get world2aligned rot and translation
1283
+ root_orient_mat = x_pred_dict["root_orient"][:, 0, :].reshape((B, 3, 3))
1284
+ world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
1285
+ world2aligned_trans = torch.cat(
1286
+ [-x_pred_dict["trans"][:, 0, :2], torch.zeros((B, 1)).to(x_past)],
1287
+ axis=1,
1288
+ )
1289
+
1290
+ #
1291
+ # transform inputs to this local frame (body pose is not affected) for next step
1292
+ #
1293
+ cur_input_dict = self.apply_world2local_trans(
1294
+ world2aligned_trans,
1295
+ world2aligned_rot,
1296
+ trans2joint,
1297
+ cur_input_dict,
1298
+ cur_input_dict,
1299
+ invert=False,
1300
+ )
1301
+
1302
+ # convert rots to correct input format
1303
+ if self.in_rot_rep == "aa":
1304
+ if "root_orient" in self.data_names:
1305
+ cur_input_dict["root_orient"] = rotation_matrix_to_angle_axis(
1306
+ cur_input_dict["root_orient"].reshape((B * S, 3, 3))
1307
+ ).reshape((B, S, 3))
1308
+ if "pose_body" in self.data_names:
1309
+ cur_input_dict["pose_body"] = rotation_matrix_to_angle_axis(
1310
+ cur_input_dict["pose_body"].reshape((B * S * (J - 1), 3, 3))
1311
+ ).reshape((B, S, (J - 1) * 3))
1312
+ elif self.in_rot_rep == "6d":
1313
+ if "root_orient" in self.data_names:
1314
+ cur_input_dict["root_orient"] = cur_input_dict["root_orient"][
1315
+ :, :, :6
1316
+ ]
1317
+ if "pose_body" in self.data_names:
1318
+ cur_input_dict["pose_body"] = (
1319
+ cur_input_dict["pose_body"]
1320
+ .reshape((B, S, J - 1, 9))[:, :, :, :6]
1321
+ .reshape((B, S, (J - 1) * 6))
1322
+ )
1323
+
1324
+ #
1325
+ # compute current world output and update world2local transform
1326
+ #
1327
+ cur_world_dict = dict()
1328
+ cur_world_dict = self.apply_world2local_trans(
1329
+ global_world2local_trans,
1330
+ global_world2local_rot,
1331
+ trans2joint,
1332
+ x_pred_dict,
1333
+ cur_world_dict,
1334
+ invert=True,
1335
+ )
1336
+ #
1337
+ # update world2local transform
1338
+ #
1339
+ global_world2local_trans = torch.cat(
1340
+ [
1341
+ -cur_world_dict["trans"][:, 0:1, :2],
1342
+ torch.zeros((B, 1, 1)).to(x_past),
1343
+ ],
1344
+ axis=2,
1345
+ )
1346
+ # print(world2aligned_rot)
1347
+ global_world2local_rot = torch.matmul(
1348
+ global_world2local_rot, world2aligned_rot.reshape((B, 1, 3, 3))
1349
+ )
1350
+
1351
+ pred_global_seq.append(cur_world_dict)
1352
+
1353
+ # cat all inputs together to form past_in
1354
+ in_data_list = []
1355
+ for k in self.data_names:
1356
+ in_data_list.append(cur_input_dict[k])
1357
+ past_in = torch.cat(in_data_list, axis=2)
1358
+ past_in = past_in.reshape((B, -1))
1359
+
1360
+ # aggregate global pred_seq
1361
+ pred_seq_out = dict()
1362
+ for k in pred_global_seq[0].keys():
1363
+ pred_seq_out[k] = torch.cat(
1364
+ [pred_global_seq[i][k] for i in range(len(pred_global_seq))], axis=1
1365
+ )
1366
+
1367
+ if return_z:
1368
+ z_out_seq = torch.stack(z_out_seq, dim=1)
1369
+ pred_seq_out["z"] = z_out_seq
1370
+
1371
+ if return_prior:
1372
+ pm = torch.stack([prior_seq[i][0] for i in range(len(prior_seq))], axis=1)
1373
+ pv = torch.stack([prior_seq[i][1] for i in range(len(prior_seq))], axis=1)
1374
+ return pred_seq_out, (pm, pv)
1375
+ else:
1376
+ return pred_seq_out
1377
+
1378
+ def sample_step(
1379
+ self,
1380
+ past_in,
1381
+ t_in=None,
1382
+ use_mean=False,
1383
+ z=None,
1384
+ return_prior=False,
1385
+ return_z=False,
1386
+ ):
1387
+ """
1388
+ Given past, samples next future state by sampling from prior or posterior and decoding.
1389
+ If z (B x D) is not None, uses the given z instead of sampling from posterior or prior
1390
+
1391
+ Returns:
1392
+ - decoder_out : (B x steps_out x D) output of the decoder for the immediate next step
1393
+ """
1394
+ B = past_in.size(0)
1395
+
1396
+ pm, pv = None, None
1397
+ if t_in is not None:
1398
+ # use past and future to encode latent transition
1399
+ pm, pv = self.posterior(past_in, t_in)
1400
+ else:
1401
+ # prior
1402
+ if self.use_conditional_prior:
1403
+ # predict prior based on past
1404
+ pm, pv = self.prior(past_in)
1405
+ else:
1406
+ # use standard normal
1407
+ pm, pv = torch.zeros((B, self.latent_size)).to(past_in), torch.ones(
1408
+ (B, self.latent_size)
1409
+ ).to(past_in)
1410
+
1411
+ # sample from distrib or use mean
1412
+ if z is None:
1413
+ if not use_mean:
1414
+ z = self.rsample(pm, pv)
1415
+ else:
1416
+ z = pm # NOTE: use mean
1417
+
1418
+ # decode to get next step
1419
+ decoder_out = self.decode(z, past_in)
1420
+ decoder_out = decoder_out.reshape(
1421
+ (B, self.steps_out, -1)
1422
+ ) # B x steps_out x D_out
1423
+
1424
+ out_dict = {"decoder_out": decoder_out}
1425
+ if return_prior:
1426
+ out_dict["prior"] = (pm, pv)
1427
+ if return_z:
1428
+ out_dict["z"] = z
1429
+
1430
+ return out_dict
1431
+
1432
+ def infer_global_seq(self, global_seq, full_forward_pass=False):
1433
+ """
1434
+ Given a sequence of global states, formats it (transform each step into local frame and makde B x steps_in x D)
1435
+ and runs inference (compute prior/posterior of z for the sequence).
1436
+
1437
+ If full_forward_pass is true, does an entire forward pass at each step rather than just inference.
1438
+
1439
+ Rotations should be in in_rot_rep format.
1440
+ """
1441
+ # used to compute output zero padding
1442
+ needed_future_steps = (self.steps_out - 1) * self.out_step_size
1443
+
1444
+ prior_m_seq = []
1445
+ prior_v_seq = []
1446
+ post_m_seq = []
1447
+ post_v_seq = []
1448
+ pred_dict_seq = []
1449
+ B, T, _ = global_seq[list(global_seq.keys())[0]].size()
1450
+ J = len(SMPL_JOINTS)
1451
+ trans2joint = None
1452
+ for t in range(T - 1):
1453
+ # get world2aligned rot and translation
1454
+ world2aligned_rot = world2aligned_trans = None
1455
+
1456
+ root_orient_mat = global_seq["root_orient"][:, t, :].reshape((B, 3, 3))
1457
+ world2aligned_rot = compute_world2aligned_mat(root_orient_mat)
1458
+ world2aligned_trans = torch.cat(
1459
+ [
1460
+ -global_seq["trans"][:, t, :2],
1461
+ torch.zeros((B, 1)).to(root_orient_mat),
1462
+ ],
1463
+ axis=1,
1464
+ )
1465
+
1466
+ # compute trans2joint at first step
1467
+ if t == 0 and self.need_trans2joint:
1468
+ trans2joint = -(
1469
+ global_seq["joints"][:, t, :2] + world2aligned_trans[:, :2]
1470
+ ) # we cannot make the assumption that the first frame is already canonical
1471
+ trans2joint = torch.cat(
1472
+ [trans2joint, torch.zeros((B, 1)).to(trans2joint)], axis=1
1473
+ ).reshape((B, 1, 1, 3))
1474
+
1475
+ # get current window
1476
+ cur_data_dict = dict()
1477
+ for k in global_seq.keys():
1478
+ # get in steps
1479
+ in_sidx = max(0, t - self.steps_in + 1)
1480
+ cur_in_seq = global_seq[k][:, in_sidx : (t + 1), :]
1481
+ if cur_in_seq.size(1) < self.steps_in:
1482
+ # must zero pad front
1483
+ num_pad_steps = self.steps_in - cur_in_seq.size(1)
1484
+ cur_padding = torch.zeros(
1485
+ (cur_in_seq.size(0), num_pad_steps, cur_in_seq.size(2))
1486
+ ).to(
1487
+ cur_in_seq
1488
+ ) # assuming all data is B x T x D
1489
+ cur_in_seq = torch.cat([cur_padding, cur_in_seq], axis=1)
1490
+
1491
+ # get out steps
1492
+ cur_out_seq = global_seq[k][
1493
+ :, (t + 1) : (t + 2 + needed_future_steps) : self.out_step_size
1494
+ ]
1495
+ if cur_out_seq.size(1) < self.steps_out:
1496
+ # zero pad
1497
+ num_pad_steps = self.steps_out - cur_out_seq.size(1)
1498
+ cur_padding = torch.zeros_like(cur_out_seq[:, 0])
1499
+ cur_padding = torch.stack([cur_padding] * num_pad_steps, axis=1)
1500
+ cur_out_seq = torch.cat([cur_out_seq, cur_padding], axis=1)
1501
+ cur_data_dict[k] = torch.cat([cur_in_seq, cur_out_seq], axis=1)
1502
+
1503
+ # transform to local frame
1504
+ cur_data_dict = self.apply_world2local_trans(
1505
+ world2aligned_trans,
1506
+ world2aligned_rot,
1507
+ trans2joint,
1508
+ cur_data_dict,
1509
+ cur_data_dict,
1510
+ invert=False,
1511
+ )
1512
+
1513
+ # create x_past and x_t
1514
+ # cat all inputs together to form past_in
1515
+ in_data_list = []
1516
+ for k in self.data_names:
1517
+ in_data_list.append(cur_data_dict[k][:, : self.steps_in, :])
1518
+ x_past = torch.cat(in_data_list, axis=2)
1519
+ # cat all outputs together to form x_t
1520
+ out_data_list = []
1521
+ for k in self.data_names:
1522
+ out_data_list.append(cur_data_dict[k][:, self.steps_in :, :])
1523
+ x_t = torch.cat(out_data_list, axis=2)
1524
+
1525
+ if full_forward_pass:
1526
+ x_pred_dict = self(x_past, x_t)
1527
+ pred_dict_seq.append(x_pred_dict)
1528
+ else:
1529
+ # perform inference
1530
+ prior_z, posterior_z = self.infer(x_past, x_t)
1531
+ # save z
1532
+ prior_m_seq.append(prior_z[0])
1533
+ prior_v_seq.append(prior_z[1])
1534
+ post_m_seq.append(posterior_z[0])
1535
+ post_v_seq.append(posterior_z[1])
1536
+
1537
+ if full_forward_pass:
1538
+ # pred_dict_seq
1539
+ pred_seq_out = dict()
1540
+ for k in pred_dict_seq[0].keys():
1541
+ # print(k)
1542
+ if k == "posterior_distrib" or k == "prior_distrib":
1543
+ m = torch.stack(
1544
+ [pred_dict_seq[i][k][0] for i in range(len(pred_dict_seq))],
1545
+ axis=1,
1546
+ )
1547
+ v = torch.stack(
1548
+ [pred_dict_seq[i][k][1] for i in range(len(pred_dict_seq))],
1549
+ axis=1,
1550
+ )
1551
+ pred_seq_out[k] = (m, v)
1552
+ else:
1553
+ pred_seq_out[k] = torch.stack(
1554
+ [pred_dict_seq[i][k] for i in range(len(pred_dict_seq))], axis=1
1555
+ )
1556
+
1557
+ return pred_seq_out
1558
+ else:
1559
+ prior_m_seq = torch.stack(prior_m_seq, axis=1)
1560
+ prior_v_seq = torch.stack(prior_v_seq, axis=1)
1561
+ post_m_seq = torch.stack(post_m_seq, axis=1)
1562
+ post_v_seq = torch.stack(post_v_seq, axis=1)
1563
+
1564
+ return (prior_m_seq, prior_v_seq), (post_m_seq, post_v_seq)
1565
+
1566
+ def infer(self, x_past, x_t):
1567
+ """
1568
+ Inference (compute prior and posterior distribution of z) for a batch of single steps.
1569
+ NOTE: must do processing before passing in to ensure correct format that this function expects.
1570
+
1571
+ Input:
1572
+ - x_past (B x steps_in x D)
1573
+ - x_t (B x steps_out x D)
1574
+
1575
+ Returns:
1576
+ - prior_distrib (mu, var)
1577
+ - posterior_distrib (mu, var)
1578
+ """
1579
+
1580
+ B, _, D = x_past.size()
1581
+ past_in = x_past.reshape((B, -1))
1582
+ t_in = x_t.reshape((B, -1))
1583
+
1584
+ prior_z, posterior_z = self.infer_step(past_in, t_in)
1585
+
1586
+ return prior_z, posterior_z
1587
+
1588
+ def infer_step(self, past_in, t_in):
1589
+ """
1590
+ single step that computes both prior and posterior for training. Samples from posterior
1591
+ """
1592
+ B = past_in.size(0)
1593
+ # use past and future to encode latent transition
1594
+ qm, qv = self.posterior(past_in, t_in)
1595
+
1596
+ # prior
1597
+ pm, pv = None, None
1598
+ if self.use_conditional_prior:
1599
+ # predict prior based on past
1600
+ pm, pv = self.prior(past_in)
1601
+ else:
1602
+ # use standard normal
1603
+ pm, pv = torch.zeros_like(qm), torch.ones_like(qv)
1604
+
1605
+ return (pm, pv), (qm, qv)
1606
+
1607
+
1608
+ class MLP(nn.Module):
1609
+ def __init__(
1610
+ self,
1611
+ layers=[3, 128, 128, 3],
1612
+ nonlinearity=nn.ReLU,
1613
+ use_gn=True,
1614
+ skip_input_idx=None,
1615
+ ):
1616
+ """
1617
+ If skip_input_idx is not None, the input feature after idx skip_input_idx will be skip connected to every later of the MLP.
1618
+ """
1619
+ super(MLP, self).__init__()
1620
+
1621
+ in_size = layers[0]
1622
+ out_channels = layers[1:]
1623
+
1624
+ # input layer
1625
+ layers = []
1626
+ layers.append(nn.Linear(in_size, out_channels[0]))
1627
+ skip_size = 0 if skip_input_idx is None else (in_size - skip_input_idx)
1628
+ # now the rest
1629
+ for layer_idx in range(1, len(out_channels)):
1630
+ fc_layer = nn.Linear(
1631
+ out_channels[layer_idx - 1] + skip_size, out_channels[layer_idx]
1632
+ )
1633
+ if use_gn:
1634
+ bn_layer = nn.GroupNorm(16, out_channels[layer_idx - 1])
1635
+ layers.append(bn_layer)
1636
+ layers.extend([nonlinearity(), fc_layer])
1637
+ self.net = nn.ModuleList(layers)
1638
+ self.skip_input_idx = skip_input_idx
1639
+
1640
+ def forward(self, x):
1641
+ """
1642
+ B x D x * : batch norm done over dim D
1643
+ """
1644
+ skip_in = None
1645
+ if self.skip_input_idx is not None:
1646
+ skip_in = x[:, self.skip_input_idx :]
1647
+ for i, layer in enumerate(self.net):
1648
+ if (
1649
+ self.skip_input_idx is not None
1650
+ and i > 0
1651
+ and isinstance(layer, nn.Linear)
1652
+ ):
1653
+ x = torch.cat([x, skip_in], dim=1)
1654
+ x = layer(x)
1655
+ return x
slahmr/slahmr/humor/transforms.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from https://github.com/davrempe/humor
3
+ """
4
+
5
+ import copy
6
+
7
+ import torch
8
+ import numpy as np
9
+ from torch.nn import functional as F
10
+
11
+ from body_model.utils import SMPL_JOINTS
12
+
13
+ #
14
+ # For computing local body frame
15
+ #
16
+
17
+ GLOB_DEVICE = (
18
+ torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
19
+ )
20
+ XY_AXIS_GLOB = torch.Tensor([[1.0, 1.0, 0.0]]).to(device=GLOB_DEVICE)
21
+ X_AXIS_GLOB = torch.Tensor([[1.0, 0.0, 0.0]]).to(device=GLOB_DEVICE)
22
+
23
+
24
+ def compute_aligned_from_right(body_right):
25
+ xy_axis = XY_AXIS_GLOB
26
+ x_axis = X_AXIS_GLOB
27
+
28
+ body_right_x_proj = body_right[:, 0:1] / (
29
+ torch.norm(body_right[:, :2], dim=1, keepdim=True) + 1e-6
30
+ )
31
+ body_right_x_proj = torch.clamp(
32
+ body_right_x_proj, min=-1.0, max=1.0
33
+ ) # avoid acos error
34
+
35
+ world2aligned_angle = torch.acos(
36
+ body_right_x_proj
37
+ ) # project to world x axis, and compute angle
38
+ body_right = body_right * xy_axis
39
+ world2aligned_axis = torch.linalg.cross(body_right, x_axis.expand_as(body_right))
40
+
41
+ world2aligned_aa = (
42
+ world2aligned_axis
43
+ / (torch.norm(world2aligned_axis, dim=1, keepdim=True) + 1e-6)
44
+ ) * world2aligned_angle
45
+ world2aligned_mat = batch_rodrigues(world2aligned_aa)
46
+
47
+ return world2aligned_mat, world2aligned_aa
48
+
49
+
50
+ def compute_world2aligned_mat(rot_pos):
51
+ """
52
+ batch of world rotation matrices: B x 3 x 3
53
+ returns rot mats that align the inputs to the forward direction: B x 3 x 3
54
+ Torch version
55
+ """
56
+ body_right = -rot_pos[:, :, 0] # .clone() # in body coordinates body x-axis is left
57
+
58
+ world2aligned_mat, world2aligned_aa = compute_aligned_from_right(body_right)
59
+ return world2aligned_mat
60
+
61
+
62
+ def compute_world2aligned_joints_mat(joints):
63
+ """
64
+ Compute world to canonical frame (rotation around up axis)
65
+ from the given batch of joints (B x J x 3)
66
+ """
67
+ left_idx = SMPL_JOINTS["leftUpLeg"]
68
+ right_idx = SMPL_JOINTS["rightUpLeg"]
69
+
70
+ body_right = joints[:, right_idx] - joints[:, left_idx]
71
+ body_right = body_right / torch.norm(body_right, dim=1, keepdim=True)
72
+
73
+ world2aligned_mat, world2aligned_aa = compute_aligned_from_right(body_right)
74
+
75
+ return world2aligned_mat
76
+
77
+
78
+ def convert_to_rotmat(pred_rot, rep="aa"):
79
+ """
80
+ Converts rotation rep to rotation matrix based on the given type.
81
+ pred_rot : B x T x N
82
+ """
83
+ B, T, _ = pred_rot.size()
84
+ pred_rot_mat = None
85
+ if rep == "aa":
86
+ pred_rot_mat = batch_rodrigues(pred_rot.reshape(-1, 3))
87
+ elif rep == "6d":
88
+ pred_rot_mat = rot6d_to_rotmat(pred_rot.reshape(-1, 6))
89
+ elif rep == "9d":
90
+ pred_rot_mat = rot9d_to_rotmat(pred_rot.reshape(-1, 9))
91
+ return pred_rot_mat.reshape((B, T, -1))
92
+
93
+
94
+ #
95
+ # Many of these functions taken from https://github.com/mkocabas/VIBE/blob/a859e45a907379aa2fba65a7b620b4a2d65dcf1b/lib/utils/geometry.py
96
+ # Please see their license for usage restrictions.
97
+ #
98
+
99
+
100
+ def matrot2axisangle(matrots):
101
+ """
102
+ :param matrots: N*num_joints*9
103
+ :return: N*num_joints*3
104
+ """
105
+ import cv2
106
+
107
+ batch_size = matrots.shape[0]
108
+ matrots = matrots.reshape([batch_size, -1, 9])
109
+ out_axisangle = []
110
+ for mIdx in range(matrots.shape[0]):
111
+ cur_axisangle = []
112
+ for jIdx in range(matrots.shape[1]):
113
+ a = cv2.Rodrigues(matrots[mIdx, jIdx : jIdx + 1, :].reshape(3, 3))[
114
+ 0
115
+ ].reshape((1, 3))
116
+ cur_axisangle.append(a)
117
+
118
+ out_axisangle.append(np.array(cur_axisangle).reshape([1, -1, 3]))
119
+ return np.vstack(out_axisangle)
120
+
121
+
122
+ def axisangle2matrots(axisangle):
123
+ """
124
+ :param axisangle: N*num_joints*3
125
+ :return: N*num_joints*9
126
+ """
127
+ import cv2
128
+
129
+ batch_size = axisangle.shape[0]
130
+ axisangle = axisangle.reshape([batch_size, -1, 3])
131
+ out_matrot = []
132
+ for mIdx in range(axisangle.shape[0]):
133
+ cur_axisangle = []
134
+ for jIdx in range(axisangle.shape[1]):
135
+ a = cv2.Rodrigues(axisangle[mIdx, jIdx : jIdx + 1, :].reshape(1, 3))[0]
136
+ cur_axisangle.append(a)
137
+
138
+ out_matrot.append(np.array(cur_axisangle).reshape([1, -1, 9]))
139
+ return np.vstack(out_matrot)
140
+
141
+
142
+ def make_rot_homog(rotation_matrix):
143
+ if rotation_matrix.shape[1:] == (3, 3):
144
+ rot_mat = rotation_matrix.reshape(-1, 3, 3)
145
+ hom = (
146
+ torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device)
147
+ .reshape(1, 3, 1)
148
+ .expand(rot_mat.shape[0], -1, -1)
149
+ )
150
+ rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
151
+ return rotation_matrix
152
+
153
+
154
+ def skew(v):
155
+ """
156
+ Returns skew symmetric (B x 3 x 3) mat from vector v: B x 3
157
+ """
158
+ B, D = v.size()
159
+ assert D == 3
160
+ skew_mat = torch.zeros((B, 3, 3)).to(v)
161
+ skew_mat[:, 0, 1] = v[:, 2]
162
+ skew_mat[:, 1, 0] = -v[:, 2]
163
+ skew_mat[:, 0, 2] = v[:, 1]
164
+ skew_mat[:, 2, 0] = -v[:, 1]
165
+ skew_mat[:, 1, 2] = v[:, 0]
166
+ skew_mat[:, 2, 1] = -v[:, 0]
167
+ return skew_mat
168
+
169
+
170
+ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
171
+ """Calculates the rotation matrices for a batch of rotation vectors
172
+ Parameters
173
+ ----------
174
+ rot_vecs: torch.tensor Nx3
175
+ array of N axis-angle vectors
176
+ Returns
177
+ -------
178
+ R: torch.tensor Nx3x3
179
+ The rotation matrices for the given axis-angle parameters
180
+ """
181
+
182
+ batch_size = rot_vecs.shape[0]
183
+ device = rot_vecs.device
184
+
185
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
186
+ rot_dir = rot_vecs / angle
187
+
188
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
189
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
190
+
191
+ # Bx1 arrays
192
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
193
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
194
+
195
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
196
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view(
197
+ (batch_size, 3, 3)
198
+ )
199
+
200
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
201
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
202
+ return rot_mat
203
+
204
+
205
+ def quat2mat(quat):
206
+ """
207
+ This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L50
208
+ Convert quaternion coefficients to rotation matrix.
209
+ Args:
210
+ quat: size = [batch_size, 4] 4 <===>(w, x, y, z)
211
+ Returns:
212
+ Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
213
+ """
214
+ norm_quat = quat
215
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
216
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
217
+
218
+ batch_size = quat.size(0)
219
+
220
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
221
+ wx, wy, wz = w * x, w * y, w * z
222
+ xy, xz, yz = x * y, x * z, y * z
223
+
224
+ rotMat = torch.stack(
225
+ [
226
+ w2 + x2 - y2 - z2,
227
+ 2 * xy - 2 * wz,
228
+ 2 * wy + 2 * xz,
229
+ 2 * wz + 2 * xy,
230
+ w2 - x2 + y2 - z2,
231
+ 2 * yz - 2 * wx,
232
+ 2 * xz - 2 * wy,
233
+ 2 * wx + 2 * yz,
234
+ w2 - x2 - y2 + z2,
235
+ ],
236
+ dim=1,
237
+ ).view(batch_size, 3, 3)
238
+ return rotMat
239
+
240
+
241
+ def rot6d_to_rotmat(x):
242
+ """Convert 6D rotation representation to 3x3 rotation matrix.
243
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
244
+ Input:
245
+ (B,6) Batch of 6-D rotation representations
246
+ Output:
247
+ (B,3,3) Batch of corresponding rotation matrices
248
+ """
249
+ x = x.view(-1, 3, 2)
250
+ a1 = x[:, :, 0]
251
+ a2 = x[:, :, 1]
252
+ b1 = F.normalize(a1)
253
+ b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
254
+
255
+ # inp = a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1
256
+ # denom = inp.pow(2).sum(dim=1).sqrt().unsqueeze(-1) + 1e-8
257
+ # b2 = inp / denom
258
+
259
+ b3 = torch.linalg.cross(b1, b2)
260
+ return torch.stack((b1, b2, b3), dim=-1)
261
+
262
+
263
+ def rot9d_to_rotmat(x):
264
+ """
265
+ Converts 9D rotation output to valid 3x3 rotation amtrix.
266
+ Based on Levinson et al., An Analysis of SVD for Deep Rotation Estimation.
267
+
268
+ Input:
269
+ (B, 9)
270
+ Output:
271
+ (B, 9)
272
+ """
273
+ B = x.size()[0]
274
+ x = x.reshape((B, 3, 3))
275
+ u, s, v = torch.svd(x)
276
+
277
+ v_T = v.transpose(-2, -1)
278
+ s_p = torch.eye(3).to(x).reshape((1, 3, 3)).expand_as(x).clone()
279
+ s_p[:, 2, 2] = torch.det(torch.matmul(u, v_T))
280
+ x_out = torch.matmul(torch.matmul(u, s_p), v_T)
281
+
282
+ return x_out.reshape((B, 9))
283
+
284
+
285
+ def rotation_matrix_to_angle_axis(rotation_matrix):
286
+ """
287
+ This function is borrowed from https://github.com/kornia/kornia
288
+ Convert 3x4 rotation matrix to Rodrigues vector
289
+ Args:
290
+ rotation_matrix (Tensor): rotation matrix.
291
+ Returns:
292
+ Tensor: Rodrigues vector transformation.
293
+ Shape:
294
+ - Input: :math:`(N, 3, 4)`
295
+ - Output: :math:`(N, 3)`
296
+ Example:
297
+ >>> input = torch.rand(2, 3, 4) # Nx4x4
298
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
299
+ """
300
+ if rotation_matrix.shape[1:] == (3, 3):
301
+ rot_mat = rotation_matrix.reshape(-1, 3, 3)
302
+ hom = (
303
+ torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device)
304
+ .reshape(1, 3, 1)
305
+ .expand(rot_mat.shape[0], -1, -1)
306
+ )
307
+ rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
308
+
309
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
310
+ aa = quaternion_to_angle_axis(quaternion)
311
+ aa[torch.isnan(aa)] = 0.0
312
+ return aa
313
+
314
+
315
+ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
316
+ """
317
+ This function is borrowed from https://github.com/kornia/kornia
318
+ Convert 3x4 rotation matrix to 4d quaternion vector
319
+ This algorithm is based on algorithm described in
320
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
321
+ Args:
322
+ rotation_matrix (Tensor): the rotation matrix to convert.
323
+ Return:
324
+ Tensor: the rotation in quaternion
325
+ Shape:
326
+ - Input: :math:`(N, 3, 4)`
327
+ - Output: :math:`(N, 4)`
328
+ Example:
329
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
330
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
331
+ """
332
+ if not torch.is_tensor(rotation_matrix):
333
+ raise TypeError(
334
+ "Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix))
335
+ )
336
+
337
+ if len(rotation_matrix.shape) > 3:
338
+ raise ValueError(
339
+ "Input size must be a three dimensional tensor. Got {}".format(
340
+ rotation_matrix.shape
341
+ )
342
+ )
343
+ if not rotation_matrix.shape[-2:] == (3, 4):
344
+ raise ValueError(
345
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(
346
+ rotation_matrix.shape
347
+ )
348
+ )
349
+
350
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
351
+
352
+ mask_d2 = rmat_t[:, 2, 2] < eps
353
+
354
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
355
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
356
+
357
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
358
+ q0 = torch.stack(
359
+ [
360
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
361
+ t0,
362
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
363
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
364
+ ],
365
+ -1,
366
+ )
367
+ t0_rep = t0.repeat(4, 1).t()
368
+
369
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
370
+ q1 = torch.stack(
371
+ [
372
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
373
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
374
+ t1,
375
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
376
+ ],
377
+ -1,
378
+ )
379
+ t1_rep = t1.repeat(4, 1).t()
380
+
381
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
382
+ q2 = torch.stack(
383
+ [
384
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
385
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
386
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
387
+ t2,
388
+ ],
389
+ -1,
390
+ )
391
+ t2_rep = t2.repeat(4, 1).t()
392
+
393
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
394
+ q3 = torch.stack(
395
+ [
396
+ t3,
397
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
398
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
399
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
400
+ ],
401
+ -1,
402
+ )
403
+ t3_rep = t3.repeat(4, 1).t()
404
+
405
+ mask_c0 = mask_d2 * mask_d0_d1
406
+ mask_c1 = mask_d2 * ~mask_d0_d1
407
+ mask_c2 = ~mask_d2 * mask_d0_nd1
408
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
409
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
410
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
411
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
412
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
413
+
414
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
415
+ q /= torch.sqrt(
416
+ t0_rep * mask_c0
417
+ + t1_rep * mask_c1
418
+ + t2_rep * mask_c2 # noqa
419
+ + t3_rep * mask_c3
420
+ ) # noqa
421
+ q *= 0.5
422
+ return q
423
+
424
+
425
+ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
426
+ """
427
+ This function is borrowed from https://github.com/kornia/kornia
428
+ Convert quaternion vector to angle axis of rotation.
429
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
430
+ Args:
431
+ quaternion (torch.Tensor): tensor with quaternions.
432
+ Return:
433
+ torch.Tensor: tensor with angle axis of rotation.
434
+ Shape:
435
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
436
+ - Output: :math:`(*, 3)`
437
+ Example:
438
+ >>> quaternion = torch.rand(2, 4) # Nx4
439
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
440
+ """
441
+ if not torch.is_tensor(quaternion):
442
+ raise TypeError(
443
+ "Input type is not a torch.Tensor. Got {}".format(type(quaternion))
444
+ )
445
+
446
+ if not quaternion.shape[-1] == 4:
447
+ raise ValueError(
448
+ "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
449
+ )
450
+ # unpack input and compute conversion
451
+ q1: torch.Tensor = quaternion[..., 1]
452
+ q2: torch.Tensor = quaternion[..., 2]
453
+ q3: torch.Tensor = quaternion[..., 3]
454
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
455
+
456
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
457
+ cos_theta: torch.Tensor = quaternion[..., 0]
458
+ two_theta: torch.Tensor = 2.0 * torch.where(
459
+ cos_theta < 0.0,
460
+ torch.atan2(-sin_theta, -cos_theta),
461
+ torch.atan2(sin_theta, cos_theta),
462
+ )
463
+
464
+ k_pos: torch.Tensor = two_theta / sin_theta
465
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
466
+ k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
467
+
468
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
469
+ angle_axis[..., 0] += q1 * k
470
+ angle_axis[..., 1] += q2 * k
471
+ angle_axis[..., 2] += q3 * k
472
+ return angle_axis
slahmr/slahmr/job_specs/3dpw_test_split.txt ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ downtown_arguing_00 data.start_idx=0 data.end_idx=100
2
+ downtown_arguing_00 data.start_idx=100 data.end_idx=200
3
+ downtown_arguing_00 data.start_idx=200 data.end_idx=300
4
+ downtown_arguing_00 data.start_idx=300 data.end_idx=400
5
+ downtown_arguing_00 data.start_idx=400 data.end_idx=500
6
+ downtown_arguing_00 data.start_idx=500 data.end_idx=600
7
+ downtown_arguing_00 data.start_idx=600 data.end_idx=700
8
+ downtown_arguing_00 data.start_idx=700 data.end_idx=898
9
+ downtown_bar_00 data.start_idx=0 data.end_idx=100
10
+ downtown_bar_00 data.start_idx=100 data.end_idx=200
11
+ downtown_bar_00 data.start_idx=200 data.end_idx=300
12
+ downtown_bar_00 data.start_idx=300 data.end_idx=400
13
+ downtown_bar_00 data.start_idx=400 data.end_idx=500
14
+ downtown_bar_00 data.start_idx=500 data.end_idx=600
15
+ downtown_bar_00 data.start_idx=600 data.end_idx=700
16
+ downtown_bar_00 data.start_idx=700 data.end_idx=800
17
+ downtown_bar_00 data.start_idx=800 data.end_idx=900
18
+ downtown_bar_00 data.start_idx=900 data.end_idx=1000
19
+ downtown_bar_00 data.start_idx=1000 data.end_idx=1100
20
+ downtown_bar_00 data.start_idx=1100 data.end_idx=1200
21
+ downtown_bar_00 data.start_idx=1200 data.end_idx=1300
22
+ downtown_bar_00 data.start_idx=1300 data.end_idx=1403
23
+ downtown_bus_00 data.start_idx=0 data.end_idx=100
24
+ downtown_bus_00 data.start_idx=100 data.end_idx=200
25
+ downtown_bus_00 data.start_idx=200 data.end_idx=300
26
+ downtown_bus_00 data.start_idx=300 data.end_idx=400
27
+ downtown_bus_00 data.start_idx=400 data.end_idx=500
28
+ downtown_bus_00 data.start_idx=500 data.end_idx=600
29
+ downtown_bus_00 data.start_idx=600 data.end_idx=700
30
+ downtown_bus_00 data.start_idx=700 data.end_idx=800
31
+ downtown_bus_00 data.start_idx=800 data.end_idx=900
32
+ downtown_bus_00 data.start_idx=900 data.end_idx=1000
33
+ downtown_bus_00 data.start_idx=1000 data.end_idx=1100
34
+ downtown_bus_00 data.start_idx=1100 data.end_idx=1200
35
+ downtown_bus_00 data.start_idx=1200 data.end_idx=1300
36
+ downtown_bus_00 data.start_idx=1300 data.end_idx=1400
37
+ downtown_bus_00 data.start_idx=1400 data.end_idx=1500
38
+ downtown_bus_00 data.start_idx=1500 data.end_idx=1600
39
+ downtown_bus_00 data.start_idx=1600 data.end_idx=1700
40
+ downtown_bus_00 data.start_idx=1700 data.end_idx=1800
41
+ downtown_bus_00 data.start_idx=1800 data.end_idx=1900
42
+ downtown_bus_00 data.start_idx=1900 data.end_idx=2000
43
+ downtown_bus_00 data.start_idx=2000 data.end_idx=2178
44
+ downtown_cafe_00 data.start_idx=0 data.end_idx=100
45
+ downtown_cafe_00 data.start_idx=100 data.end_idx=200
46
+ downtown_cafe_00 data.start_idx=200 data.end_idx=300
47
+ downtown_cafe_00 data.start_idx=300 data.end_idx=400
48
+ downtown_cafe_00 data.start_idx=400 data.end_idx=500
49
+ downtown_cafe_00 data.start_idx=500 data.end_idx=600
50
+ downtown_cafe_00 data.start_idx=600 data.end_idx=700
51
+ downtown_cafe_00 data.start_idx=700 data.end_idx=800
52
+ downtown_cafe_00 data.start_idx=800 data.end_idx=900
53
+ downtown_cafe_00 data.start_idx=900 data.end_idx=1000
54
+ downtown_cafe_00 data.start_idx=1000 data.end_idx=1100
55
+ downtown_cafe_00 data.start_idx=1100 data.end_idx=1201
56
+ downtown_car_00 data.start_idx=0 data.end_idx=100
57
+ downtown_car_00 data.start_idx=100 data.end_idx=200
58
+ downtown_car_00 data.start_idx=200 data.end_idx=300
59
+ downtown_car_00 data.start_idx=300 data.end_idx=400
60
+ downtown_car_00 data.start_idx=400 data.end_idx=500
61
+ downtown_car_00 data.start_idx=500 data.end_idx=600
62
+ downtown_car_00 data.start_idx=600 data.end_idx=700
63
+ downtown_car_00 data.start_idx=700 data.end_idx=800
64
+ downtown_car_00 data.start_idx=800 data.end_idx=900
65
+ downtown_car_00 data.start_idx=900 data.end_idx=1020
66
+ downtown_crossStreets_00 data.start_idx=0 data.end_idx=100
67
+ downtown_crossStreets_00 data.start_idx=100 data.end_idx=200
68
+ downtown_crossStreets_00 data.start_idx=200 data.end_idx=300
69
+ downtown_crossStreets_00 data.start_idx=300 data.end_idx=400
70
+ downtown_crossStreets_00 data.start_idx=400 data.end_idx=588
71
+ downtown_downstairs_00 data.start_idx=0 data.end_idx=100
72
+ downtown_downstairs_00 data.start_idx=100 data.end_idx=200
73
+ downtown_downstairs_00 data.start_idx=200 data.end_idx=300
74
+ downtown_downstairs_00 data.start_idx=300 data.end_idx=400
75
+ downtown_downstairs_00 data.start_idx=400 data.end_idx=500
76
+ downtown_downstairs_00 data.start_idx=500 data.end_idx=600
77
+ downtown_downstairs_00 data.start_idx=600 data.end_idx=700
78
+ downtown_downstairs_00 data.start_idx=700 data.end_idx=857
79
+ downtown_enterShop_00 data.start_idx=0 data.end_idx=100
80
+ downtown_enterShop_00 data.start_idx=100 data.end_idx=200
81
+ downtown_enterShop_00 data.start_idx=200 data.end_idx=300
82
+ downtown_enterShop_00 data.start_idx=300 data.end_idx=400
83
+ downtown_enterShop_00 data.start_idx=400 data.end_idx=500
84
+ downtown_enterShop_00 data.start_idx=500 data.end_idx=600
85
+ downtown_enterShop_00 data.start_idx=600 data.end_idx=700
86
+ downtown_enterShop_00 data.start_idx=700 data.end_idx=800
87
+ downtown_enterShop_00 data.start_idx=800 data.end_idx=900
88
+ downtown_enterShop_00 data.start_idx=900 data.end_idx=1000
89
+ downtown_enterShop_00 data.start_idx=1000 data.end_idx=1100
90
+ downtown_enterShop_00 data.start_idx=1100 data.end_idx=1200
91
+ downtown_enterShop_00 data.start_idx=1200 data.end_idx=1300
92
+ downtown_enterShop_00 data.start_idx=1300 data.end_idx=1449
93
+ downtown_rampAndStairs_00 data.start_idx=0 data.end_idx=100
94
+ downtown_rampAndStairs_00 data.start_idx=100 data.end_idx=200
95
+ downtown_rampAndStairs_00 data.start_idx=200 data.end_idx=300
96
+ downtown_rampAndStairs_00 data.start_idx=300 data.end_idx=400
97
+ downtown_rampAndStairs_00 data.start_idx=400 data.end_idx=500
98
+ downtown_rampAndStairs_00 data.start_idx=500 data.end_idx=600
99
+ downtown_rampAndStairs_00 data.start_idx=600 data.end_idx=700
100
+ downtown_rampAndStairs_00 data.start_idx=700 data.end_idx=800
101
+ downtown_rampAndStairs_00 data.start_idx=800 data.end_idx=984
102
+ downtown_runForBus_00 data.start_idx=0 data.end_idx=100
103
+ downtown_runForBus_00 data.start_idx=100 data.end_idx=200
104
+ downtown_runForBus_00 data.start_idx=200 data.end_idx=300
105
+ downtown_runForBus_00 data.start_idx=300 data.end_idx=400
106
+ downtown_runForBus_00 data.start_idx=400 data.end_idx=500
107
+ downtown_runForBus_00 data.start_idx=500 data.end_idx=600
108
+ downtown_runForBus_00 data.start_idx=600 data.end_idx=731
109
+ downtown_runForBus_01 data.start_idx=0 data.end_idx=100
110
+ downtown_runForBus_01 data.start_idx=100 data.end_idx=200
111
+ downtown_runForBus_01 data.start_idx=200 data.end_idx=300
112
+ downtown_runForBus_01 data.start_idx=300 data.end_idx=400
113
+ downtown_runForBus_01 data.start_idx=400 data.end_idx=500
114
+ downtown_runForBus_01 data.start_idx=500 data.end_idx=600
115
+ downtown_runForBus_01 data.start_idx=600 data.end_idx=783
116
+ downtown_sitOnStairs_00 data.start_idx=0 data.end_idx=100
117
+ downtown_sitOnStairs_00 data.start_idx=100 data.end_idx=200
118
+ downtown_sitOnStairs_00 data.start_idx=200 data.end_idx=300
119
+ downtown_sitOnStairs_00 data.start_idx=300 data.end_idx=400
120
+ downtown_sitOnStairs_00 data.start_idx=400 data.end_idx=500
121
+ downtown_sitOnStairs_00 data.start_idx=500 data.end_idx=600
122
+ downtown_sitOnStairs_00 data.start_idx=600 data.end_idx=700
123
+ downtown_sitOnStairs_00 data.start_idx=700 data.end_idx=800
124
+ downtown_sitOnStairs_00 data.start_idx=800 data.end_idx=900
125
+ downtown_sitOnStairs_00 data.start_idx=900 data.end_idx=1000
126
+ downtown_sitOnStairs_00 data.start_idx=1000 data.end_idx=1100
127
+ downtown_sitOnStairs_00 data.start_idx=1100 data.end_idx=1200
128
+ downtown_sitOnStairs_00 data.start_idx=1200 data.end_idx=1337
129
+ downtown_stairs_00 data.start_idx=0 data.end_idx=100
130
+ downtown_stairs_00 data.start_idx=100 data.end_idx=200
131
+ downtown_stairs_00 data.start_idx=200 data.end_idx=300
132
+ downtown_stairs_00 data.start_idx=300 data.end_idx=400
133
+ downtown_stairs_00 data.start_idx=400 data.end_idx=500
134
+ downtown_stairs_00 data.start_idx=500 data.end_idx=600
135
+ downtown_stairs_00 data.start_idx=600 data.end_idx=700
136
+ downtown_stairs_00 data.start_idx=700 data.end_idx=800
137
+ downtown_stairs_00 data.start_idx=800 data.end_idx=900
138
+ downtown_stairs_00 data.start_idx=900 data.end_idx=1000
139
+ downtown_stairs_00 data.start_idx=1000 data.end_idx=1100
140
+ downtown_stairs_00 data.start_idx=1100 data.end_idx=1240
141
+ downtown_upstairs_00 data.start_idx=0 data.end_idx=100
142
+ downtown_upstairs_00 data.start_idx=100 data.end_idx=200
143
+ downtown_upstairs_00 data.start_idx=200 data.end_idx=300
144
+ downtown_upstairs_00 data.start_idx=300 data.end_idx=400
145
+ downtown_upstairs_00 data.start_idx=400 data.end_idx=500
146
+ downtown_upstairs_00 data.start_idx=500 data.end_idx=600
147
+ downtown_upstairs_00 data.start_idx=600 data.end_idx=700
148
+ downtown_upstairs_00 data.start_idx=700 data.end_idx=845
149
+ downtown_walkBridge_01 data.start_idx=0 data.end_idx=100
150
+ downtown_walkBridge_01 data.start_idx=100 data.end_idx=200
151
+ downtown_walkBridge_01 data.start_idx=200 data.end_idx=300
152
+ downtown_walkBridge_01 data.start_idx=300 data.end_idx=400
153
+ downtown_walkBridge_01 data.start_idx=400 data.end_idx=500
154
+ downtown_walkBridge_01 data.start_idx=500 data.end_idx=600
155
+ downtown_walkBridge_01 data.start_idx=600 data.end_idx=700
156
+ downtown_walkBridge_01 data.start_idx=700 data.end_idx=800
157
+ downtown_walkBridge_01 data.start_idx=800 data.end_idx=900
158
+ downtown_walkBridge_01 data.start_idx=900 data.end_idx=1000
159
+ downtown_walkBridge_01 data.start_idx=1000 data.end_idx=1100
160
+ downtown_walkBridge_01 data.start_idx=1100 data.end_idx=1200
161
+ downtown_walkBridge_01 data.start_idx=1200 data.end_idx=1372
162
+ downtown_walkUphill_00 data.start_idx=0 data.end_idx=100
163
+ downtown_walkUphill_00 data.start_idx=100 data.end_idx=200
164
+ downtown_walkUphill_00 data.start_idx=200 data.end_idx=388
165
+ downtown_walking_00 data.start_idx=0 data.end_idx=100
166
+ downtown_walking_00 data.start_idx=100 data.end_idx=200
167
+ downtown_walking_00 data.start_idx=200 data.end_idx=300
168
+ downtown_walking_00 data.start_idx=300 data.end_idx=400
169
+ downtown_walking_00 data.start_idx=400 data.end_idx=500
170
+ downtown_walking_00 data.start_idx=500 data.end_idx=600
171
+ downtown_walking_00 data.start_idx=600 data.end_idx=700
172
+ downtown_walking_00 data.start_idx=700 data.end_idx=800
173
+ downtown_walking_00 data.start_idx=800 data.end_idx=900
174
+ downtown_walking_00 data.start_idx=900 data.end_idx=1000
175
+ downtown_walking_00 data.start_idx=1000 data.end_idx=1100
176
+ downtown_walking_00 data.start_idx=1100 data.end_idx=1200
177
+ downtown_walking_00 data.start_idx=1200 data.end_idx=1387
178
+ downtown_warmWelcome_00 data.start_idx=0 data.end_idx=100
179
+ downtown_warmWelcome_00 data.start_idx=100 data.end_idx=200
180
+ downtown_warmWelcome_00 data.start_idx=200 data.end_idx=300
181
+ downtown_warmWelcome_00 data.start_idx=300 data.end_idx=400
182
+ downtown_warmWelcome_00 data.start_idx=400 data.end_idx=589
183
+ downtown_weeklyMarket_00 data.start_idx=0 data.end_idx=100
184
+ downtown_weeklyMarket_00 data.start_idx=100 data.end_idx=200
185
+ downtown_weeklyMarket_00 data.start_idx=200 data.end_idx=300
186
+ downtown_weeklyMarket_00 data.start_idx=300 data.end_idx=400
187
+ downtown_weeklyMarket_00 data.start_idx=400 data.end_idx=500
188
+ downtown_weeklyMarket_00 data.start_idx=500 data.end_idx=600
189
+ downtown_weeklyMarket_00 data.start_idx=600 data.end_idx=700
190
+ downtown_weeklyMarket_00 data.start_idx=700 data.end_idx=800
191
+ downtown_weeklyMarket_00 data.start_idx=800 data.end_idx=900
192
+ downtown_weeklyMarket_00 data.start_idx=900 data.end_idx=1000
193
+ downtown_weeklyMarket_00 data.start_idx=1000 data.end_idx=1193
194
+ downtown_windowShopping_00 data.start_idx=0 data.end_idx=100
195
+ downtown_windowShopping_00 data.start_idx=100 data.end_idx=200
196
+ downtown_windowShopping_00 data.start_idx=200 data.end_idx=300
197
+ downtown_windowShopping_00 data.start_idx=300 data.end_idx=400
198
+ downtown_windowShopping_00 data.start_idx=400 data.end_idx=500
199
+ downtown_windowShopping_00 data.start_idx=500 data.end_idx=600
200
+ downtown_windowShopping_00 data.start_idx=600 data.end_idx=700
201
+ downtown_windowShopping_00 data.start_idx=700 data.end_idx=800
202
+ downtown_windowShopping_00 data.start_idx=800 data.end_idx=900
203
+ downtown_windowShopping_00 data.start_idx=900 data.end_idx=1000
204
+ downtown_windowShopping_00 data.start_idx=1000 data.end_idx=1100
205
+ downtown_windowShopping_00 data.start_idx=1100 data.end_idx=1200
206
+ downtown_windowShopping_00 data.start_idx=1200 data.end_idx=1300
207
+ downtown_windowShopping_00 data.start_idx=1300 data.end_idx=1400
208
+ downtown_windowShopping_00 data.start_idx=1400 data.end_idx=1500
209
+ downtown_windowShopping_00 data.start_idx=1500 data.end_idx=1600
210
+ downtown_windowShopping_00 data.start_idx=1600 data.end_idx=1700
211
+ downtown_windowShopping_00 data.start_idx=1700 data.end_idx=1800
212
+ downtown_windowShopping_00 data.start_idx=1800 data.end_idx=1948
213
+ flat_guitar_01 data.start_idx=0 data.end_idx=100
214
+ flat_guitar_01 data.start_idx=100 data.end_idx=200
215
+ flat_guitar_01 data.start_idx=200 data.end_idx=300
216
+ flat_guitar_01 data.start_idx=300 data.end_idx=400
217
+ flat_guitar_01 data.start_idx=400 data.end_idx=500
218
+ flat_guitar_01 data.start_idx=500 data.end_idx=600
219
+ flat_guitar_01 data.start_idx=600 data.end_idx=748
220
+ flat_packBags_00 data.start_idx=0 data.end_idx=100
221
+ flat_packBags_00 data.start_idx=100 data.end_idx=200
222
+ flat_packBags_00 data.start_idx=200 data.end_idx=300
223
+ flat_packBags_00 data.start_idx=300 data.end_idx=400
224
+ flat_packBags_00 data.start_idx=400 data.end_idx=500
225
+ flat_packBags_00 data.start_idx=500 data.end_idx=600
226
+ flat_packBags_00 data.start_idx=600 data.end_idx=700
227
+ flat_packBags_00 data.start_idx=700 data.end_idx=800
228
+ flat_packBags_00 data.start_idx=800 data.end_idx=900
229
+ flat_packBags_00 data.start_idx=900 data.end_idx=1000
230
+ flat_packBags_00 data.start_idx=1000 data.end_idx=1100
231
+ flat_packBags_00 data.start_idx=1100 data.end_idx=1279
232
+ office_phoneCall_00 data.start_idx=0 data.end_idx=100
233
+ office_phoneCall_00 data.start_idx=100 data.end_idx=200
234
+ office_phoneCall_00 data.start_idx=200 data.end_idx=300
235
+ office_phoneCall_00 data.start_idx=300 data.end_idx=400
236
+ office_phoneCall_00 data.start_idx=400 data.end_idx=500
237
+ office_phoneCall_00 data.start_idx=500 data.end_idx=600
238
+ office_phoneCall_00 data.start_idx=600 data.end_idx=700
239
+ office_phoneCall_00 data.start_idx=700 data.end_idx=880
240
+ outdoors_fencing_01 data.start_idx=0 data.end_idx=100
241
+ outdoors_fencing_01 data.start_idx=100 data.end_idx=200
242
+ outdoors_fencing_01 data.start_idx=200 data.end_idx=300
243
+ outdoors_fencing_01 data.start_idx=300 data.end_idx=400
244
+ outdoors_fencing_01 data.start_idx=400 data.end_idx=500
245
+ outdoors_fencing_01 data.start_idx=500 data.end_idx=600
246
+ outdoors_fencing_01 data.start_idx=600 data.end_idx=700
247
+ outdoors_fencing_01 data.start_idx=700 data.end_idx=800
248
+ outdoors_fencing_01 data.start_idx=800 data.end_idx=942
slahmr/slahmr/job_specs/davis.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ parkour data=davis fps=24
2
+ lady-running data=davis fps=24
3
+ dance-twirl data=davis fps=24
4
+ lindy-hop data=davis fps=24
5
+ hike data=davis fps=24
6
+ judo data=davis fps=24
7
+ lucia data=davis fps=24
8
+ tennis data=davis fps=24
9
+ skate-park data=davis fps=24
10
+ boxing-fisheye data=davis fps=24
11
+ crossing data=davis fps=24
12
+ loading data=davis fps=24
13
+ bike-packing data=davis fps=24
14
+ dance-jump data=davis fps=24
15
+ hockey data=davis fps=24
16
+ india data=davis fps=24
17
+ kid-football data=davis fps=24
18
+ longboard data=davis fps=24
19
+ schoolgirls data=davis fps=24
20
+ snowboard data=davis fps=24
21
+ stunt data=davis fps=24
22
+ swing data=davis fps=24
23
+ dancing data=davis fps=24
24
+ kite-walk data=davis fps=24