Spaces:
Sleeping
Sleeping
chingshuai commited on
Commit ·
76957e3
1
Parent(s): 2fd136b
init
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -0
- .gitignore +26 -0
- README.md +91 -7
- assets/arch.png +3 -0
- assets/banner.png +3 -0
- assets/config_simplified.yml +37 -0
- assets/pipeline.png +3 -0
- assets/sotacomp.png +3 -0
- assets/teaser.png +3 -0
- assets/wooden_models/boy_Rigging_smplx_tex.fbx +3 -0
- examples/example_prompts/example_subset.json +61 -0
- gradio_app.py +894 -0
- hymotion/network/attention.py +110 -0
- hymotion/network/bricks.py +46 -0
- hymotion/network/encoders.py +121 -0
- hymotion/network/hymotion_mmdit.py +636 -0
- hymotion/network/modulate_layers.py +49 -0
- hymotion/network/positional_encoding.py +174 -0
- hymotion/network/text_encoders/model_constants.py +8 -0
- hymotion/network/text_encoders/text_encoder.py +293 -0
- hymotion/network/token_refiner.py +192 -0
- hymotion/pipeline/body_model.py +412 -0
- hymotion/pipeline/motion_diffusion.py +639 -0
- hymotion/prompt_engineering/model_constants.py +42 -0
- hymotion/prompt_engineering/prompt_rewrite.py +284 -0
- hymotion/utils/configs.py +344 -0
- hymotion/utils/geometry.py +856 -0
- hymotion/utils/loaders.py +184 -0
- hymotion/utils/misc.py +113 -0
- hymotion/utils/motion_process.py +63 -0
- hymotion/utils/path.py +168 -0
- hymotion/utils/smplh2woodfbx.py +626 -0
- hymotion/utils/t2m_runtime.py +400 -0
- hymotion/utils/type_converter.py +22 -0
- hymotion/utils/visualize_mesh_web.py +463 -0
- requirements.txt +24 -0
- scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp +3 -0
- scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp +3 -0
- scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp +3 -0
- scripts/gradio/static/assets/dump_wooden/faces.bin +3 -0
- scripts/gradio/static/assets/dump_wooden/j_template.bin +3 -0
- scripts/gradio/static/assets/dump_wooden/joint_names.json +54 -0
- scripts/gradio/static/assets/dump_wooden/joints.ply +0 -0
- scripts/gradio/static/assets/dump_wooden/keypoints.bin +3 -0
- scripts/gradio/static/assets/dump_wooden/kintree.bin +3 -0
- scripts/gradio/static/assets/dump_wooden/skinIndice.bin +3 -0
- scripts/gradio/static/assets/dump_wooden/skinWeights.bin +3 -0
- scripts/gradio/static/assets/dump_wooden/uvs.bin +3 -0
- scripts/gradio/static/assets/dump_wooden/v_template.bin +3 -0
- scripts/gradio/templates/index_wooden_static.html +1205 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
3rdparty/fbxsdkpy-2020.1.post2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/wooden_models/*.fbx filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/wooden_models/boy_Rigging_smplx_tex.fbm/Boy_lambert4_BaseColor.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
*.whl filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cache_config.json
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
*.pyw
|
| 7 |
+
*.pyz
|
| 8 |
+
*.pywz
|
| 9 |
+
*.pyzw
|
| 10 |
+
*.pyzwz
|
| 11 |
+
cache
|
| 12 |
+
|
| 13 |
+
.vscode/
|
| 14 |
+
|
| 15 |
+
ckpts/*
|
| 16 |
+
!ckpts/README.md
|
| 17 |
+
assets/body_models/*
|
| 18 |
+
!assets/body_models/README.md
|
| 19 |
+
scripts/gradio/static/assets/dump_smplh
|
| 20 |
+
scripts/gradio/static/assets/export_wooden_to_js.py
|
| 21 |
+
|
| 22 |
+
test_*/
|
| 23 |
+
debug
|
| 24 |
+
tencent
|
| 25 |
+
output
|
| 26 |
+
assets/wooden_models/boy_Rigging_smplx_tex.fbm/*
|
README.md
CHANGED
|
@@ -1,12 +1,96 @@
|
|
| 1 |
---
|
| 2 |
-
title: HY
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
app_file:
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: HY-Motion-1.0
|
| 3 |
+
emoji: 💃
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
+
app_file: gradio_app.py
|
| 9 |
pinned: false
|
| 10 |
+
short_description: Text-to-3D and Image-to-3D Generation
|
| 11 |
---
|
| 12 |
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
<img src="./assets/banner.png" alt="Banner" width="100%">
|
| 16 |
+
</p>
|
| 17 |
+
|
| 18 |
+
<div align="center">
|
| 19 |
+
<a href="https://hunyuan.tencent.com/motion" target="_blank">
|
| 20 |
+
<img src="https://img.shields.io/badge/Official%20Site-333399.svg?logo=homepage" height="22px" alt="Official Site">
|
| 21 |
+
</a>
|
| 22 |
+
<a href="https://huggingface.co/spaces/tencent/HY-Motion-1.0" target="_blank">
|
| 23 |
+
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Demo-276cb4.svg" height="22px" alt="HuggingFace Space">
|
| 24 |
+
</a>
|
| 25 |
+
<a href="https://huggingface.co/tencent/HY-Motion-1.0" target="_blank">
|
| 26 |
+
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg" height="22px" alt="HuggingFace Models">
|
| 27 |
+
</a>
|
| 28 |
+
<a href="https://arxiv.org/pdf/2512.23464" target="_blank">
|
| 29 |
+
<img src="https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv" height="22px" alt="ArXiv Report">
|
| 30 |
+
</a>
|
| 31 |
+
<a href="https://x.com/TencentHunyuan" target="_blank">
|
| 32 |
+
<img src="https://img.shields.io/badge/Hunyuan-black.svg?logo=x" height="22px" alt="X (Twitter)">
|
| 33 |
+
</a>
|
| 34 |
+
</div>
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# HY-Motion 1.0: Scaling Flow Matching Models for 3D Motion Generation
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
<p align="center">
|
| 41 |
+
<img src="./assets/teaser.png" alt="Teaser" width="90%">
|
| 42 |
+
</p>
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
## 🔥 News
|
| 46 |
+
- **Dec 30, 2025**: 🤗 We released the inference code and pretrained models of [HY-Motion 1.0](https://huggingface.co/tencent/HY-Motion-1.0). Please give it a try via our [HuggingFace Space](https://huggingface.co/spaces/tencent/HY-Motion-1.0) and our [Official Site](https://hunyuan.tencent.com/motion)!
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
## **Introduction**
|
| 50 |
+
|
| 51 |
+
**HY-Motion 1.0** is a series of text-to-3D human motion generation models based on Diffusion Transformer (DiT) and Flow Matching. It allows developers to generate skeleton-based 3D character animations from simple text prompts, which can be directly integrated into various 3D animation pipelines. This model series is the first to scale DiT-based text-to-motion models to the billion-parameter level, achieving significant improvements in instruction-following capabilities and motion quality over existing open-source models.
|
| 52 |
+
|
| 53 |
+
### Key Features
|
| 54 |
+
- **State-of-the-Art Performance**: Achieves state-of-the-art performance in both instruction-following capability and generated motion quality.
|
| 55 |
+
|
| 56 |
+
- **Billion-Scale Models**: We are the first to successfully scale DiT-based models to the billion-parameter level for text-to-motion generation. This results in superior instruction understanding and following capabilities, outperforming comparable open-source models.
|
| 57 |
+
|
| 58 |
+
- **Advanced Three-Stage Training**: Our models are trained using a comprehensive three-stage process:
|
| 59 |
+
|
| 60 |
+
- *Large-Scale Pre-training*: Trained on over 3,000 hours of diverse motion data to learn a broad motion prior.
|
| 61 |
+
|
| 62 |
+
- *High-Quality Fine-tuning*: Fine-tuned on 400 hours of curated, high-quality 3D motion data to enhance motion detail and smoothness.
|
| 63 |
+
|
| 64 |
+
- *Reinforcement Learning*: Utilizes Reinforcement Learning from human feedback and reward models to further refine instruction-following and motion naturalness.
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
<p align="center">
|
| 69 |
+
<img src="./assets/pipeline.png" alt="System Overview" width="100%">
|
| 70 |
+
</p>
|
| 71 |
+
|
| 72 |
+
<p align="center">
|
| 73 |
+
<img src="./assets/arch.png" alt="Architecture" width="100%">
|
| 74 |
+
</p>
|
| 75 |
+
|
| 76 |
+
<p align="center">
|
| 77 |
+
<img src="./assets/sotacomp.png" alt="ComparisonSoTA" width="100%">
|
| 78 |
+
</p>
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
## 🔗 BibTeX
|
| 82 |
+
|
| 83 |
+
If you found this repository helpful, please cite our reports:
|
| 84 |
+
|
| 85 |
+
```bibtex
|
| 86 |
+
@article{hymotion2025,
|
| 87 |
+
title={HY-Motion 1.0: Scaling Flow Matching Models for Text-To-Motion Generation},
|
| 88 |
+
author={Tencent Hunyuan 3D Digital Human Team},
|
| 89 |
+
journal={arXiv preprint arXiv:2512.23464},
|
| 90 |
+
year={2025}
|
| 91 |
+
}
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## Acknowledgements
|
| 95 |
+
|
| 96 |
+
We would like to thank the contributors to the [FLUX](https://github.com/black-forest-labs/flux), [diffusers](https://github.com/huggingface/diffusers), [HuggingFace](https://huggingface.co), [SMPL](https://smpl.is.tue.mpg.de/)/[SMPLH](https://mano.is.tue.mpg.de/), [CLIP](https://github.com/openai/CLIP), [Qwen3](https://github.com/QwenLM/Qwen3), [PyTorch3D](https://github.com/facebookresearch/pytorch3d), [kornia](https://github.com/kornia/kornia), [transforms3d](https://github.com/matthew-brett/transforms3d), [FBX-SDK](https://www.autodesk.com/developer-network/platform-technologies/fbx-sdk-2020-0), [GVHMR](https://zju3dv.github.io/gvhmr/), and [HunyuanVideo](https://github.com/Tencent-Hunyuan/HunyuanVideo) repositories or tools, for their open research and exploration.
|
assets/arch.png
ADDED
|
Git LFS Details
|
assets/banner.png
ADDED
|
Git LFS Details
|
assets/config_simplified.yml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
network_module: hymotion/network/hymotion_mmdit.HunyuanMotionMMDiT
|
| 2 |
+
network_module_args:
|
| 3 |
+
apply_rope_to_single_branch: false
|
| 4 |
+
ctxt_input_dim: 4096
|
| 5 |
+
dropout: 0.0
|
| 6 |
+
feat_dim: 1024
|
| 7 |
+
input_dim: 201
|
| 8 |
+
mask_mode: narrowband
|
| 9 |
+
mlp_ratio: 4.0
|
| 10 |
+
num_heads: 16
|
| 11 |
+
num_layers: 18
|
| 12 |
+
time_factor: 1000.0
|
| 13 |
+
vtxt_input_dim: 768
|
| 14 |
+
train_pipeline: hymotion/pipeline/motion_diffusion.MotionFlowMatching
|
| 15 |
+
train_pipeline_args:
|
| 16 |
+
enable_ctxt_null_feat: true
|
| 17 |
+
enable_special_game_feat: true
|
| 18 |
+
infer_noise_scheduler_cfg:
|
| 19 |
+
validation_steps: 50
|
| 20 |
+
losses_cfg:
|
| 21 |
+
recons:
|
| 22 |
+
name: SmoothL1Loss
|
| 23 |
+
weight: 1.0
|
| 24 |
+
noise_scheduler_cfg:
|
| 25 |
+
method: euler
|
| 26 |
+
output_mesh_fps: 30
|
| 27 |
+
random_generator_on_gpu: true
|
| 28 |
+
test_cfg:
|
| 29 |
+
mean_std_dir: ./stats/
|
| 30 |
+
text_guidance_scale: 5.0
|
| 31 |
+
text_encoder_cfg:
|
| 32 |
+
llm_type: qwen3
|
| 33 |
+
max_length_llm: 128
|
| 34 |
+
text_encoder_module: hymotion/network/text_encoders/text_encoder.HYTextModel
|
| 35 |
+
train_cfg:
|
| 36 |
+
cond_mask_prob: 0.1
|
| 37 |
+
train_frames: 360
|
assets/pipeline.png
ADDED
|
Git LFS Details
|
assets/sotacomp.png
ADDED
|
Git LFS Details
|
assets/teaser.png
ADDED
|
Git LFS Details
|
assets/wooden_models/boy_Rigging_smplx_tex.fbx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e1a4fc5b121d5fa61a631ee22ba360ca128279d794d1ed75b2acb9486e71cc8
|
| 3 |
+
size 16490768
|
examples/example_prompts/example_subset.json
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"test_prompts_subset": [
|
| 3 |
+
"A person jumps upward with both legs twice.#90#001",
|
| 4 |
+
"A person jumps on their right leg.#90#002",
|
| 5 |
+
"A person climbs upward, moving up the slope.#60#003",
|
| 6 |
+
"A person climbs an obstacle.#60#004",
|
| 7 |
+
"A person walks forward.#120#005",
|
| 8 |
+
"A person walks forward, moving arms and legs while looking left and right.#180#006",
|
| 9 |
+
"A person walks unsteadily, then slowly sits down.#150#007",
|
| 10 |
+
"A person turns backward 180 degrees, then walks forward.#120#008",
|
| 11 |
+
"A person walks in a catwalk style, swinging their left arm while placing their right hand on their hip.#180#009",
|
| 12 |
+
"A person squats down on tiptoe#120#010",
|
| 13 |
+
"A person sits down on a chair.#90#011",
|
| 14 |
+
"A person runs forward.#60#012",
|
| 15 |
+
"A person jumps up.#90#013",
|
| 16 |
+
"A person jumps forward lightly, taking two steps.#69#014",
|
| 17 |
+
"A person shoots a basketball.#60#015",
|
| 18 |
+
"A person finishes freestyle swimming, then surfaces.#120#016",
|
| 19 |
+
"A person swings a golf club, hitting the ball forward.#111#017",
|
| 20 |
+
"A person runs forward, then kicks a soccer ball.#60#018",
|
| 21 |
+
"A person walks on a tightrope.#180#019",
|
| 22 |
+
"A person performs a yoga camel pose, extending their back and lifting their chest.#210#020",
|
| 23 |
+
"A person performs a sit-up, holding their head with both hands.#150#021",
|
| 24 |
+
"A person performs a lunge stretch, hands on hips.#150#022",
|
| 25 |
+
"A person performs a deadlift, lifting a barbell from the ground.#150#023",
|
| 26 |
+
"A person marches in place, swinging their arms forward and backward.#210#024",
|
| 27 |
+
"A person perform a squat, not standing up#93#025",
|
| 28 |
+
"A person performs a squat#93#026",
|
| 29 |
+
"A person performs a front arm raise, then does a squat.#93#027",
|
| 30 |
+
"A person performs a squat, raising both arms forward.#240#028",
|
| 31 |
+
"A person does a squat, balling both hands into fists, lowering into a squat, then standing up.#195#029",
|
| 32 |
+
"A person plays the piano.#270#030",
|
| 33 |
+
"A person dances bachata, executing rhythmic hip movements and footwork.#240#031",
|
| 34 |
+
"A person plays the drums while sitting down, with wide, crossing arm movements.#90#032",
|
| 35 |
+
"A person plays the drums while sitting down, with arms spreading wide and then crossing over.#90#033",
|
| 36 |
+
"A person dances jazz, jumping rhythmically.#240#034",
|
| 37 |
+
"A person practices tai chi, performing slow, controlled movements.#270#035",
|
| 38 |
+
"A person waves their right hand, sitting on a beach chair.#71#036",
|
| 39 |
+
"A person was sweeping the floor with their head down.#180#037",
|
| 40 |
+
"A person picks up an object from ground#117#038",
|
| 41 |
+
"A person picks up an object from lower ground with two hands#99#039",
|
| 42 |
+
"A person picks up an object from lower ground with two hands, and lifts over head#126#040",
|
| 43 |
+
"A person speaks, gesturing with both hands.#75#041",
|
| 44 |
+
"A person lies on a bed, reading a book.#180#042",
|
| 45 |
+
"A person bends down to pick up an object, then stands up straight.#150#043",
|
| 46 |
+
"A person flips the wok#61#044",
|
| 47 |
+
"A person rolls over while lying down.#60#045",
|
| 48 |
+
"A person walks forward, holding a tray at shoulder height with one hand.#93#046",
|
| 49 |
+
"A person stands up from the chair, then stretches the arms.#300#047",
|
| 50 |
+
"A person turns to evade.#61#048",
|
| 51 |
+
"A person collapses to the ground after being hit.#60#049",
|
| 52 |
+
"A person swings a sword forward.#60#050",
|
| 53 |
+
"A person attacks, holding a shield in the right hand and a sword in the left.#45#051",
|
| 54 |
+
"A person walks like a zombie, dragging their feet forward.#120#052",
|
| 55 |
+
"A person performs a taekwondo kick, extending their leg forcefully.#60#053",
|
| 56 |
+
"A person blocks with a shield.#60#054",
|
| 57 |
+
"A person lifts a long gun, then walks forward slowly.#90#055",
|
| 58 |
+
"A person stumbles, being hit.#45#056",
|
| 59 |
+
"A person assumes a boxing stance, then shifts weight to the right and punches with the right hand.#60#057"
|
| 60 |
+
]
|
| 61 |
+
}
|
gradio_app.py
ADDED
|
@@ -0,0 +1,894 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import codecs as cs
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import random
|
| 7 |
+
import re
|
| 8 |
+
import textwrap
|
| 9 |
+
from typing import List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import torch
|
| 13 |
+
from huggingface_hub import snapshot_download
|
| 14 |
+
|
| 15 |
+
def try_to_download_model():
|
| 16 |
+
repo_id = "tencent/HY-Motion-1.0"
|
| 17 |
+
target_folder = "HY-Motion-1.0-Lite"
|
| 18 |
+
print(f">>> start download ", repo_id, target_folder)
|
| 19 |
+
local_dir = snapshot_download(
|
| 20 |
+
repo_id=repo_id,
|
| 21 |
+
allow_patterns=f"{target_folder}/*",
|
| 22 |
+
local_dir="./downloaded_models"
|
| 23 |
+
)
|
| 24 |
+
final_model_path = os.path.join(local_dir, target_folder)
|
| 25 |
+
print(f">>> Final model path: {final_model_path}")
|
| 26 |
+
return final_model_path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Import spaces for Hugging Face Zero GPU support
|
| 30 |
+
try:
|
| 31 |
+
import spaces
|
| 32 |
+
SPACES_AVAILABLE = True
|
| 33 |
+
except ImportError:
|
| 34 |
+
SPACES_AVAILABLE = False
|
| 35 |
+
# Create a dummy decorator when spaces is not available
|
| 36 |
+
class spaces:
|
| 37 |
+
@staticmethod
|
| 38 |
+
def GPU(func=None, duration=None):
|
| 39 |
+
def decorator(fn):
|
| 40 |
+
return fn
|
| 41 |
+
if func is not None:
|
| 42 |
+
return func
|
| 43 |
+
return decorator
|
| 44 |
+
|
| 45 |
+
from hymotion.utils.t2m_runtime import T2MRuntime
|
| 46 |
+
|
| 47 |
+
NUM_WORKERS = torch.cuda.device_count() if torch.cuda.is_available() else 1
|
| 48 |
+
|
| 49 |
+
# Global runtime instance for Zero GPU lazy loading
|
| 50 |
+
_global_runtime = None
|
| 51 |
+
_global_args = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _init_runtime_if_needed():
|
| 55 |
+
"""Initialize runtime lazily for Zero GPU support."""
|
| 56 |
+
global _global_runtime, _global_args
|
| 57 |
+
if _global_runtime is not None:
|
| 58 |
+
return _global_runtime
|
| 59 |
+
|
| 60 |
+
if _global_args is None:
|
| 61 |
+
raise RuntimeError("Runtime args not set. Call set_runtime_args() first.")
|
| 62 |
+
|
| 63 |
+
args = _global_args
|
| 64 |
+
cfg = osp.join(args.model_path, "config.yml")
|
| 65 |
+
ckpt = osp.join(args.model_path, "latest.ckpt")
|
| 66 |
+
|
| 67 |
+
skip_model_loading = False
|
| 68 |
+
if not os.path.exists(ckpt):
|
| 69 |
+
print(f">>> [WARNING] Checkpoint file not found: {ckpt}")
|
| 70 |
+
print(f">>> [WARNING] Model loading will be skipped. Motion generation will not be available.")
|
| 71 |
+
skip_model_loading = True
|
| 72 |
+
|
| 73 |
+
print(">>> Initializing T2MRuntime...")
|
| 74 |
+
if "USE_HF_MODELS" not in os.environ:
|
| 75 |
+
os.environ["USE_HF_MODELS"] = "1"
|
| 76 |
+
|
| 77 |
+
skip_text = False
|
| 78 |
+
_global_runtime = T2MRuntime(
|
| 79 |
+
config_path=cfg,
|
| 80 |
+
ckpt_name=ckpt,
|
| 81 |
+
skip_text=skip_text,
|
| 82 |
+
device_ids=None,
|
| 83 |
+
prompt_engineering_host=args.prompt_engineering_host,
|
| 84 |
+
skip_model_loading=skip_model_loading,
|
| 85 |
+
)
|
| 86 |
+
return _global_runtime
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@spaces.GPU(duration=120)
|
| 90 |
+
def generate_motion_on_gpu(
|
| 91 |
+
text: str,
|
| 92 |
+
seeds_csv: str,
|
| 93 |
+
motion_duration: float,
|
| 94 |
+
cfg_scale: float,
|
| 95 |
+
output_format: str,
|
| 96 |
+
original_text: str,
|
| 97 |
+
output_dir: str,
|
| 98 |
+
) -> Tuple[str, List[str]]:
|
| 99 |
+
"""
|
| 100 |
+
GPU-decorated function for motion generation.
|
| 101 |
+
This function will request GPU allocation on Hugging Face Zero GPU.
|
| 102 |
+
"""
|
| 103 |
+
runtime = _init_runtime_if_needed()
|
| 104 |
+
|
| 105 |
+
html_content, fbx_files, _ = runtime.generate_motion(
|
| 106 |
+
text=text,
|
| 107 |
+
seeds_csv=seeds_csv,
|
| 108 |
+
duration=motion_duration,
|
| 109 |
+
cfg_scale=cfg_scale,
|
| 110 |
+
output_format=output_format,
|
| 111 |
+
original_text=original_text,
|
| 112 |
+
output_dir=output_dir,
|
| 113 |
+
)
|
| 114 |
+
return html_content, fbx_files
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# define data sources
|
| 118 |
+
DATA_SOURCES = {
|
| 119 |
+
"example_prompts": "examples/example_prompts/example_subset.json",
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
# create interface
|
| 123 |
+
APP_CSS = """
|
| 124 |
+
:root{
|
| 125 |
+
--primary-start:#667eea; --primary-end:#764ba2;
|
| 126 |
+
--secondary-start:#4facfe; --secondary-end:#00f2fe;
|
| 127 |
+
--accent-start:#f093fb; --accent-end:#f5576c;
|
| 128 |
+
--page-bg:linear-gradient(135deg,#f5f7fa 0%,#c3cfe2 100%);
|
| 129 |
+
--card-bg:linear-gradient(135deg,#ffffff 0%,#f8f9fa 100%);
|
| 130 |
+
--radius:12px;
|
| 131 |
+
--iframe-bg:#ffffff;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
/* Dark mode variables */
|
| 135 |
+
[data-theme="dark"], .dark {
|
| 136 |
+
--page-bg:linear-gradient(135deg,#1a1a1a 0%,#2d3748 100%);
|
| 137 |
+
--card-bg:linear-gradient(135deg,#2d3748 0%,#374151 100%);
|
| 138 |
+
--text-primary:#f7fafc;
|
| 139 |
+
--text-secondary:#e2e8f0;
|
| 140 |
+
--border-color:#4a5568;
|
| 141 |
+
--input-bg:#374151;
|
| 142 |
+
--input-border:#4a5568;
|
| 143 |
+
--iframe-bg:#1a1a2e;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
/* Page and card */
|
| 147 |
+
.gradio-container{
|
| 148 |
+
background:var(--page-bg) !important;
|
| 149 |
+
min-height:100vh !important;
|
| 150 |
+
color:var(--text-primary, #333) !important;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
.main-header{
|
| 154 |
+
background:transparent !important; border:none !important; box-shadow:none !important;
|
| 155 |
+
padding:0 !important; margin:10px 0 16px !important;
|
| 156 |
+
text-align:center !important;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
.main-header h1, .main-header p, .main-header li {
|
| 160 |
+
color:var(--text-primary, #333) !important;
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
.left-panel,.right-panel{
|
| 164 |
+
background:var(--card-bg) !important;
|
| 165 |
+
border:1px solid var(--border-color, #e9ecef) !important;
|
| 166 |
+
border-radius:15px !important;
|
| 167 |
+
box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
|
| 168 |
+
padding:24px !important;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
.gradio-accordion{
|
| 172 |
+
border:1px solid var(--border-color, #e1e5e9) !important;
|
| 173 |
+
border-radius:var(--radius) !important;
|
| 174 |
+
margin:12px 0 !important; background:transparent !important;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
.gradio-accordion summary{
|
| 178 |
+
background:transparent !important;
|
| 179 |
+
padding:14px 18px !important;
|
| 180 |
+
font-weight:600 !important;
|
| 181 |
+
color:var(--text-primary, #495057) !important;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
.gradio-group{
|
| 185 |
+
background:transparent !important; border:none !important;
|
| 186 |
+
border-radius:8px !important; padding:12px 0 !important; margin:8px 0 !important;
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
/* Input class style - dark mode adaptation */
|
| 190 |
+
.gradio-textbox input,.gradio-textbox textarea,.gradio-dropdown .wrap{
|
| 191 |
+
border-radius:8px !important;
|
| 192 |
+
border:2px solid var(--input-border, #e9ecef) !important;
|
| 193 |
+
background:var(--input-bg, #fff) !important;
|
| 194 |
+
color:var(--text-primary, #333) !important;
|
| 195 |
+
transition:.2s all !important;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
.gradio-textbox input:focus,.gradio-textbox textarea:focus,.gradio-dropdown .wrap:focus-within{
|
| 199 |
+
border-color:var(--primary-start) !important;
|
| 200 |
+
box-shadow:0 0 0 3px rgba(102,126,234,.1) !important;
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
.gradio-slider input[type="range"]{
|
| 204 |
+
background:linear-gradient(to right,var(--primary-start),var(--primary-end)) !important;
|
| 205 |
+
border-radius:10px !important;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
.gradio-checkbox input[type="checkbox"]{
|
| 209 |
+
border-radius:4px !important;
|
| 210 |
+
border:2px solid var(--input-border, #e9ecef) !important;
|
| 211 |
+
transition:.2s all !important;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
.gradio-checkbox input[type="checkbox"]:checked{
|
| 215 |
+
background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important;
|
| 216 |
+
border-color:var(--primary-start) !important;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
/* Label text color adaptation */
|
| 220 |
+
.gradio-textbox label, .gradio-dropdown label, .gradio-slider label,
|
| 221 |
+
.gradio-checkbox label, .gradio-html label {
|
| 222 |
+
color:var(--text-primary, #333) !important;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
.gradio-textbox .info, .gradio-dropdown .info, .gradio-slider .info,
|
| 226 |
+
.gradio-checkbox .info {
|
| 227 |
+
color:var(--text-secondary, #666) !important;
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
/* Status information - dark mode adaptation */
|
| 231 |
+
.gradio-textbox[data-testid*="状态信息"] input{
|
| 232 |
+
background:var(--input-bg, linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%)) !important;
|
| 233 |
+
border:2px solid var(--input-border, #dee2e6) !important;
|
| 234 |
+
color:var(--text-primary, #495057) !important;
|
| 235 |
+
font-weight:500 !important;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
/* Button base class and variant */
|
| 239 |
+
.generate-button,.rewrite-button,.dice-button{
|
| 240 |
+
border:none !important; color:#fff !important; font-weight:600 !important;
|
| 241 |
+
border-radius:8px !important; transition:.3s all !important;
|
| 242 |
+
box-shadow:0 4px 15px rgba(0,0,0,.12) !important;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
.generate-button{ background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important; }
|
| 246 |
+
.rewrite-button{ background:linear-gradient(45deg,var(--secondary-start),var(--secondary-end)) !important; }
|
| 247 |
+
.dice-button{
|
| 248 |
+
background:linear-gradient(45deg,var(--accent-start),var(--accent-end)) !important;
|
| 249 |
+
height:40px !important;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
.generate-button:hover,.rewrite-button:hover{ transform:translateY(-2px) !important; }
|
| 253 |
+
.dice-button:hover{
|
| 254 |
+
transform:scale(1.05) !important;
|
| 255 |
+
box-shadow:0 4px 12px rgba(240,147,251,.28) !important;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
.dice-container{
|
| 259 |
+
display:flex !important;
|
| 260 |
+
align-items:flex-end !important;
|
| 261 |
+
justify-content:center !important;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
/* Right panel clipping overflow, avoid double scrollbars */
|
| 265 |
+
.right-panel{
|
| 266 |
+
background:var(--card-bg) !important;
|
| 267 |
+
border:1px solid var(--border-color, #e9ecef) !important;
|
| 268 |
+
border-radius:15px !important;
|
| 269 |
+
box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
|
| 270 |
+
padding:24px !important; overflow:hidden !important;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
/* Main content row - ensure equal heights */
|
| 274 |
+
.main-row {
|
| 275 |
+
display: flex !important;
|
| 276 |
+
align-items: stretch !important;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
/* Flask area - match left panel height */
|
| 280 |
+
.flask-display{
|
| 281 |
+
padding:0 !important; margin:0 !important; border:none !important;
|
| 282 |
+
box-shadow:none !important; background:var(--iframe-bg) !important;
|
| 283 |
+
border-radius:10px !important; position:relative !important;
|
| 284 |
+
height:100% !important; min-height:750px !important;
|
| 285 |
+
display:flex !important; flex-direction:column !important;
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
.flask-display iframe{
|
| 289 |
+
width:100% !important; flex:1 !important; min-height:750px !important;
|
| 290 |
+
border:none !important; border-radius:10px !important; display:block !important;
|
| 291 |
+
background:var(--iframe-bg) !important;
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
/* Right panel should stretch to match left panel */
|
| 295 |
+
.right-panel{
|
| 296 |
+
background:var(--card-bg) !important;
|
| 297 |
+
border:1px solid var(--border-color, #e9ecef) !important;
|
| 298 |
+
border-radius:15px !important;
|
| 299 |
+
box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
|
| 300 |
+
padding:24px !important; overflow:hidden !important;
|
| 301 |
+
display:flex !important; flex-direction:column !important;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
/* Ensure dropdown menu is visible in dark mode */
|
| 305 |
+
[data-theme="dark"] .gradio-dropdown .wrap,
|
| 306 |
+
.dark .gradio-dropdown .wrap {
|
| 307 |
+
background:var(--input-bg) !important;
|
| 308 |
+
color:var(--text-primary) !important;
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
[data-theme="dark"] .gradio-dropdown .option,
|
| 312 |
+
.dark .gradio-dropdown .option {
|
| 313 |
+
background:var(--input-bg) !important;
|
| 314 |
+
color:var(--text-primary) !important;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
[data-theme="dark"] .gradio-dropdown .option:hover,
|
| 318 |
+
.dark .gradio-dropdown .option:hover {
|
| 319 |
+
background:var(--border-color) !important;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
.footer{
|
| 323 |
+
text-align:center !important;
|
| 324 |
+
margin-top:20px !important;
|
| 325 |
+
padding:10px !important;
|
| 326 |
+
color:var(--text-secondary, #666) !important;
|
| 327 |
+
}
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground"
|
| 331 |
+
|
| 332 |
+
FOOTER_MD = "*This is a Beta version, any issues or feedback are welcome!*"
|
| 333 |
+
|
| 334 |
+
HTML_OUTPUT_PLACEHOLDER = """
|
| 335 |
+
<div style='height: 750px; width: 100%; border-radius: 8px; border-color: #e5e7eb; border-style: solid; border-width: 1px; display: flex; justify-content: center; align-items: center;'>
|
| 336 |
+
<div style='text-align: center; font-size: 16px; color: #6b7280;'>
|
| 337 |
+
<p style="color: #8d8d8d;">Welcome to HY-Motion-1.0!</p>
|
| 338 |
+
<p style="color: #8d8d8d;">No motion visualization here yet.</p>
|
| 339 |
+
</div>
|
| 340 |
+
</div>
|
| 341 |
+
"""
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12):
|
| 345 |
+
"""Load examples from txt file."""
|
| 346 |
+
|
| 347 |
+
def _parse_line(line: str) -> Optional[Tuple[str, float]]:
|
| 348 |
+
line = line.strip()
|
| 349 |
+
if line and not line.startswith("#"):
|
| 350 |
+
parts = line.split("#")
|
| 351 |
+
if len(parts) >= 2:
|
| 352 |
+
text = parts[0].strip()
|
| 353 |
+
duration = int(parts[1]) / example_record_fps
|
| 354 |
+
duration = min(duration, max_duration)
|
| 355 |
+
else:
|
| 356 |
+
text = line.strip()
|
| 357 |
+
duration = 5.0
|
| 358 |
+
return text, duration
|
| 359 |
+
return None
|
| 360 |
+
|
| 361 |
+
examples: List[Tuple[str, float]] = []
|
| 362 |
+
if os.path.exists(txt_path):
|
| 363 |
+
try:
|
| 364 |
+
if txt_path.endswith(".txt"):
|
| 365 |
+
with cs.open(txt_path, "r", encoding="utf-8") as f:
|
| 366 |
+
lines = f.readlines()
|
| 367 |
+
for line in lines:
|
| 368 |
+
result = _parse_line(line)
|
| 369 |
+
if result is None:
|
| 370 |
+
continue
|
| 371 |
+
text, duration = result
|
| 372 |
+
examples.append((text, duration))
|
| 373 |
+
elif txt_path.endswith(".json"):
|
| 374 |
+
with cs.open(txt_path, "r", encoding="utf-8") as f:
|
| 375 |
+
lines = json.load(f)
|
| 376 |
+
for key, value in lines.items():
|
| 377 |
+
if "_raw_chn" in key or "GENERATE_PROMPT_FORMAT" in key:
|
| 378 |
+
continue
|
| 379 |
+
for line in value:
|
| 380 |
+
result = _parse_line(line)
|
| 381 |
+
if result is None:
|
| 382 |
+
continue
|
| 383 |
+
text, duration = result
|
| 384 |
+
examples.append((text, duration))
|
| 385 |
+
print(f">>> Loaded {len(examples)} examples from {txt_path}")
|
| 386 |
+
except Exception as e:
|
| 387 |
+
print(f">>> Failed to load examples from {txt_path}: {e}")
|
| 388 |
+
else:
|
| 389 |
+
print(f">>> Examples file not found: {txt_path}")
|
| 390 |
+
|
| 391 |
+
return examples
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class T2MGradioUI:
|
| 395 |
+
def __init__(self, runtime: T2MRuntime, args: argparse.Namespace):
|
| 396 |
+
self.runtime = runtime
|
| 397 |
+
self.args = args
|
| 398 |
+
|
| 399 |
+
# Check if rewrite is available:
|
| 400 |
+
# - prompt_engineering_host must be provided
|
| 401 |
+
# - disable_rewrite must not be set
|
| 402 |
+
print(f">>> args: {vars(args)}")
|
| 403 |
+
self.rewrite_available = (
|
| 404 |
+
args.prompt_engineering_host is not None
|
| 405 |
+
and args.prompt_engineering_host.strip() != ""
|
| 406 |
+
and not args.disable_rewrite
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
self.all_example_data = {}
|
| 410 |
+
self._init_example_data()
|
| 411 |
+
|
| 412 |
+
def _init_example_data(self):
|
| 413 |
+
for source_name, file_path in DATA_SOURCES.items():
|
| 414 |
+
examples = load_examples_from_txt(file_path)
|
| 415 |
+
if examples:
|
| 416 |
+
self.all_example_data[source_name] = examples
|
| 417 |
+
else:
|
| 418 |
+
# provide default examples as fallback
|
| 419 |
+
self.all_example_data[source_name] = [
|
| 420 |
+
("Twist at the waist and punch across the body.", 3.0),
|
| 421 |
+
("A person is running then takes big leap.", 3.0),
|
| 422 |
+
("A person holds a railing and walks down a set of stairs.", 5.0),
|
| 423 |
+
(
|
| 424 |
+
"A man performs a fluid and rhythmic hip-hop style dance, incorporating body waves, arm gestures, and side steps.",
|
| 425 |
+
5.0,
|
| 426 |
+
),
|
| 427 |
+
]
|
| 428 |
+
print(f">>> Loaded data sources: {list(self.all_example_data.keys())}")
|
| 429 |
+
|
| 430 |
+
def _get_header_text(self):
|
| 431 |
+
return HEADER_BASE_MD
|
| 432 |
+
|
| 433 |
+
def _generate_random_seeds(self):
|
| 434 |
+
seeds = [random.randint(0, 999) for _ in range(4)]
|
| 435 |
+
return ",".join(map(str, seeds))
|
| 436 |
+
|
| 437 |
+
def _prompt_engineering(
|
| 438 |
+
self, text: str, duration: float, enable_rewrite: bool = True, enable_duration_est: bool = True
|
| 439 |
+
):
|
| 440 |
+
if not text.strip():
|
| 441 |
+
return "", gr.update(interactive=False), gr.update()
|
| 442 |
+
|
| 443 |
+
call_llm = enable_rewrite or enable_duration_est
|
| 444 |
+
if not call_llm:
|
| 445 |
+
print(f"\t>>> Using original duration and original text...")
|
| 446 |
+
predicted_duration = duration
|
| 447 |
+
rewritten_text = text
|
| 448 |
+
else:
|
| 449 |
+
print(f"\t>>> Using LLM to estimate duration/rewrite text...")
|
| 450 |
+
try:
|
| 451 |
+
predicted_duration, rewritten_text = self.runtime.rewrite_text_and_infer_time(text=text)
|
| 452 |
+
except Exception as e:
|
| 453 |
+
print(f"\t>>> Text rewriting/duration prediction failed: {e}")
|
| 454 |
+
return (
|
| 455 |
+
f"❌ Text rewriting/duration prediction failed: {str(e)}",
|
| 456 |
+
gr.update(interactive=False),
|
| 457 |
+
gr.update(),
|
| 458 |
+
)
|
| 459 |
+
if not enable_rewrite:
|
| 460 |
+
rewritten_text = text
|
| 461 |
+
if not enable_duration_est:
|
| 462 |
+
predicted_duration = duration
|
| 463 |
+
|
| 464 |
+
return rewritten_text, gr.update(interactive=True), gr.update(value=predicted_duration)
|
| 465 |
+
|
| 466 |
+
def _generate_motion(
|
| 467 |
+
self,
|
| 468 |
+
original_text: str,
|
| 469 |
+
rewritten_text: str,
|
| 470 |
+
seed_input: str,
|
| 471 |
+
duration: float,
|
| 472 |
+
cfg_scale: float,
|
| 473 |
+
) -> Tuple[str, List[str]]:
|
| 474 |
+
# When rewrite is not available, use original_text directly
|
| 475 |
+
if not self.rewrite_available:
|
| 476 |
+
text_to_use = original_text.strip()
|
| 477 |
+
if not text_to_use:
|
| 478 |
+
return "Error: Input text is empty, please enter text first", []
|
| 479 |
+
else:
|
| 480 |
+
text_to_use = rewritten_text.strip()
|
| 481 |
+
if not text_to_use:
|
| 482 |
+
return "Error: Rewritten text is empty, please rewrite the text first", []
|
| 483 |
+
|
| 484 |
+
try:
|
| 485 |
+
# Use runtime from global if available (for Zero GPU), otherwise use self.runtime
|
| 486 |
+
runtime = _global_runtime if _global_runtime is not None else self.runtime
|
| 487 |
+
fbx_ok = getattr(runtime, "fbx_available", False)
|
| 488 |
+
req_format = "fbx" if fbx_ok else "dict"
|
| 489 |
+
|
| 490 |
+
# Use GPU-decorated function for Zero GPU support
|
| 491 |
+
html_content, fbx_files = generate_motion_on_gpu(
|
| 492 |
+
text=text_to_use,
|
| 493 |
+
seeds_csv=seed_input,
|
| 494 |
+
motion_duration=duration,
|
| 495 |
+
cfg_scale=cfg_scale,
|
| 496 |
+
output_format=req_format,
|
| 497 |
+
original_text=original_text,
|
| 498 |
+
output_dir=self.args.output_dir,
|
| 499 |
+
)
|
| 500 |
+
# Escape HTML content for srcdoc attribute
|
| 501 |
+
escaped_html = html_content.replace('"', '"')
|
| 502 |
+
# Return iframe with srcdoc - directly embed HTML content
|
| 503 |
+
iframe_html = f'''
|
| 504 |
+
<iframe
|
| 505 |
+
srcdoc="{escaped_html}"
|
| 506 |
+
width="100%"
|
| 507 |
+
height="750px"
|
| 508 |
+
style="border: none; border-radius: 12px; box-shadow: 0 4px 20px rgba(0,0,0,0.1);"
|
| 509 |
+
></iframe>
|
| 510 |
+
'''
|
| 511 |
+
return iframe_html, fbx_files
|
| 512 |
+
except Exception as e:
|
| 513 |
+
print(f"\t>>> Motion generation failed: {e}")
|
| 514 |
+
return (
|
| 515 |
+
f"❌ Motion generation failed: {str(e)}\n\nPlease check the input parameters or try again later",
|
| 516 |
+
[],
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
def _get_example_choices(self):
|
| 520 |
+
"""Get all example choices from all data sources"""
|
| 521 |
+
choices = ["Custom Input"]
|
| 522 |
+
for source_name in self.all_example_data:
|
| 523 |
+
example_data = self.all_example_data[source_name]
|
| 524 |
+
for text, _ in example_data:
|
| 525 |
+
display_text = f"{text[:50]}..." if len(text) > 50 else text
|
| 526 |
+
choices.append(display_text)
|
| 527 |
+
return choices
|
| 528 |
+
|
| 529 |
+
def _on_example_select(self, selected_example):
|
| 530 |
+
"""When selecting an example, the callback function"""
|
| 531 |
+
if selected_example == "Custom Input":
|
| 532 |
+
return "", self._generate_random_seeds(), gr.update()
|
| 533 |
+
else:
|
| 534 |
+
# find the corresponding example from all data sources
|
| 535 |
+
for source_name in self.all_example_data:
|
| 536 |
+
example_data = self.all_example_data[source_name]
|
| 537 |
+
for text, duration in example_data:
|
| 538 |
+
display_text = f"{text[:50]}..." if len(text) > 50 else text
|
| 539 |
+
if display_text == selected_example:
|
| 540 |
+
return text, self._generate_random_seeds(), gr.update(value=duration)
|
| 541 |
+
return "", self._generate_random_seeds(), gr.update()
|
| 542 |
+
|
| 543 |
+
def build_ui(self):
|
| 544 |
+
with gr.Blocks(css=APP_CSS) as demo:
|
| 545 |
+
self.header_md = gr.Markdown(HEADER_BASE_MD, elem_classes=["main-header"])
|
| 546 |
+
|
| 547 |
+
with gr.Row():
|
| 548 |
+
# Left control panel
|
| 549 |
+
with gr.Column(scale=2, elem_classes=["left-panel"]):
|
| 550 |
+
# Input textbox
|
| 551 |
+
self.text_input = gr.Textbox(
|
| 552 |
+
label="📝 Input Text",
|
| 553 |
+
placeholder="Enter text to generate motion, support Chinese and English text input.",
|
| 554 |
+
)
|
| 555 |
+
# Rewritten textbox
|
| 556 |
+
self.rewritten_text = gr.Textbox(
|
| 557 |
+
label="✏️ Rewritten Text",
|
| 558 |
+
placeholder="Rewritten text will be displayed here, you can further edit",
|
| 559 |
+
interactive=True,
|
| 560 |
+
visible=False,
|
| 561 |
+
)
|
| 562 |
+
# Duration slider
|
| 563 |
+
self.duration_slider = gr.Slider(
|
| 564 |
+
minimum=0.5,
|
| 565 |
+
maximum=12,
|
| 566 |
+
value=5.0,
|
| 567 |
+
step=0.1,
|
| 568 |
+
label="⏱️ Action Duration (seconds)",
|
| 569 |
+
info="Feel free to adjust the action duration",
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
# Execute buttons
|
| 573 |
+
with gr.Row():
|
| 574 |
+
if self.rewrite_available:
|
| 575 |
+
self.rewrite_btn = gr.Button(
|
| 576 |
+
"🔄 Rewrite Text",
|
| 577 |
+
variant="secondary",
|
| 578 |
+
size="lg",
|
| 579 |
+
elem_classes=["rewrite-button"],
|
| 580 |
+
)
|
| 581 |
+
else:
|
| 582 |
+
# Create a hidden/disabled placeholder button
|
| 583 |
+
self.rewrite_btn = gr.Button(
|
| 584 |
+
"🔄 Rewrite Text (Unavailable)",
|
| 585 |
+
variant="secondary",
|
| 586 |
+
size="lg",
|
| 587 |
+
elem_classes=["rewrite-button"],
|
| 588 |
+
interactive=False,
|
| 589 |
+
visible=False,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
self.generate_btn = gr.Button(
|
| 593 |
+
"🚀 Generate Motion",
|
| 594 |
+
variant="primary",
|
| 595 |
+
size="lg",
|
| 596 |
+
elem_classes=["generate-button"],
|
| 597 |
+
interactive=not self.rewrite_available, # Enable directly if rewrite not available
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
if not self.rewrite_available:
|
| 601 |
+
gr.Markdown(
|
| 602 |
+
"> ⚠️ **Prompt engineering is not available.** Text rewriting and duration estimation are disabled. Your input text and duration will be used directly."
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# Advanced settings
|
| 606 |
+
with gr.Accordion("🔧 Advanced Settings", open=False):
|
| 607 |
+
self._build_advanced_settings()
|
| 608 |
+
|
| 609 |
+
# Example selection dropdown
|
| 610 |
+
self.example_dropdown = gr.Dropdown(
|
| 611 |
+
choices=self._get_example_choices(),
|
| 612 |
+
value="Custom Input",
|
| 613 |
+
label="📚 Test Examples",
|
| 614 |
+
info="Select a preset example or input your own text above",
|
| 615 |
+
interactive=True,
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
# Status message depends on whether rewrite is available
|
| 619 |
+
if self.rewrite_available:
|
| 620 |
+
status_msg = "Please click the [🔄 Rewrite Text] button to rewrite the text first"
|
| 621 |
+
else:
|
| 622 |
+
status_msg = "Enter your text and click [🚀 Generate Motion] directly."
|
| 623 |
+
|
| 624 |
+
self.status_output = gr.Textbox(
|
| 625 |
+
label="📊 Status Information",
|
| 626 |
+
value=status_msg,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# FBX Download section
|
| 630 |
+
with gr.Row(visible=False) as self.fbx_download_row:
|
| 631 |
+
if getattr(self.runtime, "fbx_available", False):
|
| 632 |
+
self.fbx_files = gr.File(
|
| 633 |
+
label="📦 Download FBX Files",
|
| 634 |
+
file_count="multiple",
|
| 635 |
+
interactive=False,
|
| 636 |
+
)
|
| 637 |
+
else:
|
| 638 |
+
self.fbx_files = gr.State([])
|
| 639 |
+
|
| 640 |
+
# Right display area
|
| 641 |
+
with gr.Column(scale=3):
|
| 642 |
+
self.output_display = gr.HTML(
|
| 643 |
+
value=HTML_OUTPUT_PLACEHOLDER,
|
| 644 |
+
show_label=False,
|
| 645 |
+
elem_classes=["flask-display"]
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
# Footer
|
| 649 |
+
gr.Markdown(FOOTER_MD, elem_classes=["footer"])
|
| 650 |
+
|
| 651 |
+
self._bind_events()
|
| 652 |
+
demo.load(fn=self._get_header_text, outputs=[self.header_md])
|
| 653 |
+
return demo
|
| 654 |
+
|
| 655 |
+
def _build_advanced_settings(self):
|
| 656 |
+
# Only show rewrite options if rewrite is available
|
| 657 |
+
if self.rewrite_available:
|
| 658 |
+
with gr.Group():
|
| 659 |
+
gr.Markdown("### 🔄 Text Rewriting Options")
|
| 660 |
+
with gr.Row():
|
| 661 |
+
self.enable_rewrite = gr.Checkbox(
|
| 662 |
+
label="Enable Text Rewriting",
|
| 663 |
+
value=True,
|
| 664 |
+
info="Automatically optimize text prompt to get better motion generation",
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
with gr.Group():
|
| 668 |
+
gr.Markdown("### ⏱️ Duration Settings")
|
| 669 |
+
self.enable_duration_est = gr.Checkbox(
|
| 670 |
+
label="Enable Duration Estimation",
|
| 671 |
+
value=True,
|
| 672 |
+
info="Automatically estimate the duration of the motion",
|
| 673 |
+
)
|
| 674 |
+
else:
|
| 675 |
+
# Create hidden placeholders with default values (disabled)
|
| 676 |
+
self.enable_rewrite = gr.Checkbox(
|
| 677 |
+
label="Enable Text Rewriting",
|
| 678 |
+
value=False,
|
| 679 |
+
visible=False,
|
| 680 |
+
)
|
| 681 |
+
self.enable_duration_est = gr.Checkbox(
|
| 682 |
+
label="Enable Duration Estimation",
|
| 683 |
+
value=False,
|
| 684 |
+
visible=False,
|
| 685 |
+
)
|
| 686 |
+
with gr.Group():
|
| 687 |
+
gr.Markdown("### ⚠️ Prompt Engineering Unavailable")
|
| 688 |
+
gr.Markdown(
|
| 689 |
+
"Text rewriting and duration estimation are not available. "
|
| 690 |
+
"Your input text and duration will be used directly."
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
with gr.Group():
|
| 694 |
+
gr.Markdown("### ⚙️ Generation Parameters")
|
| 695 |
+
with gr.Row():
|
| 696 |
+
with gr.Column(scale=3):
|
| 697 |
+
self.seed_input = gr.Textbox(
|
| 698 |
+
label="🎯 Random Seed List (comma separated)",
|
| 699 |
+
value="0,1,2,3",
|
| 700 |
+
placeholder="Enter comma separated seed list (e.g.: 0,1,2,3)",
|
| 701 |
+
info="Random seeds control the diversity of generated motions",
|
| 702 |
+
)
|
| 703 |
+
with gr.Column(scale=1, min_width=60, elem_classes=["dice-container"]):
|
| 704 |
+
self.dice_btn = gr.Button(
|
| 705 |
+
"🎲 Lucky Button",
|
| 706 |
+
variant="secondary",
|
| 707 |
+
size="sm",
|
| 708 |
+
elem_classes=["dice-button"],
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
self.cfg_slider = gr.Slider(
|
| 712 |
+
minimum=1,
|
| 713 |
+
maximum=10,
|
| 714 |
+
value=5.0,
|
| 715 |
+
step=0.1,
|
| 716 |
+
label="⚙️ CFG Strength",
|
| 717 |
+
info="Text fidelity: higher = more faithful to the prompt",
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
def _bind_events(self):
|
| 721 |
+
# Generate random seeds
|
| 722 |
+
self.dice_btn.click(self._generate_random_seeds, outputs=[self.seed_input])
|
| 723 |
+
|
| 724 |
+
# Bind example selection event
|
| 725 |
+
self.example_dropdown.change(
|
| 726 |
+
fn=self._on_example_select,
|
| 727 |
+
inputs=[self.example_dropdown],
|
| 728 |
+
outputs=[self.text_input, self.seed_input, self.duration_slider],
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
# Rewrite text logic (only bind when rewrite is available)
|
| 732 |
+
if self.rewrite_available:
|
| 733 |
+
self.rewrite_btn.click(fn=lambda: "Rewriting text, please wait...", outputs=[self.status_output]).then(
|
| 734 |
+
self._prompt_engineering,
|
| 735 |
+
inputs=[
|
| 736 |
+
self.text_input,
|
| 737 |
+
self.duration_slider,
|
| 738 |
+
self.enable_rewrite,
|
| 739 |
+
self.enable_duration_est,
|
| 740 |
+
],
|
| 741 |
+
outputs=[self.rewritten_text, self.generate_btn, self.duration_slider],
|
| 742 |
+
).then(
|
| 743 |
+
fn=lambda: (
|
| 744 |
+
gr.update(visible=True),
|
| 745 |
+
"Text rewriting completed! Please check and edit the rewritten text, then click [🚀 Generate Motion]",
|
| 746 |
+
),
|
| 747 |
+
outputs=[self.rewritten_text, self.status_output],
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
# Generate motion logic
|
| 751 |
+
self.generate_btn.click(
|
| 752 |
+
fn=lambda: "Generating motion, please wait... (It takes some extra time to start the renderer for the first generation)",
|
| 753 |
+
outputs=[self.status_output],
|
| 754 |
+
).then(
|
| 755 |
+
self._generate_motion,
|
| 756 |
+
inputs=[
|
| 757 |
+
self.text_input,
|
| 758 |
+
self.rewritten_text,
|
| 759 |
+
self.seed_input,
|
| 760 |
+
self.duration_slider,
|
| 761 |
+
self.cfg_slider,
|
| 762 |
+
],
|
| 763 |
+
outputs=[self.output_display, self.fbx_files],
|
| 764 |
+
concurrency_limit=NUM_WORKERS,
|
| 765 |
+
).then(
|
| 766 |
+
fn=lambda fbx_list: (
|
| 767 |
+
(
|
| 768 |
+
"🎉 Motion generation completed! You can view the motion visualization result on the right. FBX files are ready for download."
|
| 769 |
+
if fbx_list
|
| 770 |
+
else "🎉 Motion generation completed! You can view the motion visualization result on the right"
|
| 771 |
+
),
|
| 772 |
+
gr.update(visible=bool(fbx_list)),
|
| 773 |
+
),
|
| 774 |
+
inputs=[self.fbx_files],
|
| 775 |
+
outputs=[self.status_output, self.fbx_download_row],
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
# Reset logic - different behavior based on rewrite availability
|
| 779 |
+
if self.rewrite_available:
|
| 780 |
+
self.text_input.change(
|
| 781 |
+
fn=lambda: (
|
| 782 |
+
gr.update(visible=False),
|
| 783 |
+
gr.update(interactive=False),
|
| 784 |
+
"Please click the [🔄 Rewrite Text] button to rewrite the text first",
|
| 785 |
+
),
|
| 786 |
+
outputs=[self.rewritten_text, self.generate_btn, self.status_output],
|
| 787 |
+
)
|
| 788 |
+
else:
|
| 789 |
+
# When rewrite is not available, enable generate button directly when text is entered
|
| 790 |
+
self.text_input.change(
|
| 791 |
+
fn=lambda text: (
|
| 792 |
+
gr.update(visible=False),
|
| 793 |
+
gr.update(interactive=bool(text.strip())),
|
| 794 |
+
(
|
| 795 |
+
"Ready to generate! Click [🚀 Generate Motion] to start."
|
| 796 |
+
if text.strip()
|
| 797 |
+
else "Enter your text and click [🚀 Generate Motion] directly."
|
| 798 |
+
),
|
| 799 |
+
),
|
| 800 |
+
inputs=[self.text_input],
|
| 801 |
+
outputs=[self.rewritten_text, self.generate_btn, self.status_output],
|
| 802 |
+
)
|
| 803 |
+
# Only bind rewritten_text change when rewrite is available
|
| 804 |
+
if self.rewrite_available:
|
| 805 |
+
self.rewritten_text.change(
|
| 806 |
+
fn=lambda text: (
|
| 807 |
+
gr.update(interactive=bool(text.strip())),
|
| 808 |
+
(
|
| 809 |
+
"Rewritten text has been modified, you can click [🚀 Generate Motion]"
|
| 810 |
+
if text.strip()
|
| 811 |
+
else "Rewritten text cannot be empty, please enter valid text"
|
| 812 |
+
),
|
| 813 |
+
),
|
| 814 |
+
inputs=[self.rewritten_text],
|
| 815 |
+
outputs=[self.generate_btn, self.status_output],
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
def create_demo(final_model_path):
|
| 820 |
+
"""Create the Gradio demo with Zero GPU support."""
|
| 821 |
+
global _global_runtime, _global_args
|
| 822 |
+
|
| 823 |
+
class Args:
|
| 824 |
+
model_path = final_model_path
|
| 825 |
+
output_dir = "output/gradio"
|
| 826 |
+
prompt_engineering_host = os.environ.get("PROMPT_HOST", None)
|
| 827 |
+
disable_rewrite = False
|
| 828 |
+
|
| 829 |
+
args = Args()
|
| 830 |
+
_global_args = args # Set global args for lazy loading
|
| 831 |
+
|
| 832 |
+
# Check required files:
|
| 833 |
+
cfg = osp.join(args.model_path, "config.yml")
|
| 834 |
+
ckpt = osp.join(args.model_path, "latest.ckpt")
|
| 835 |
+
if not osp.exists(cfg):
|
| 836 |
+
raise FileNotFoundError(f">>> Configuration file not found: {cfg}")
|
| 837 |
+
|
| 838 |
+
# Create output directory
|
| 839 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 840 |
+
|
| 841 |
+
# For Zero GPU: Don't load model at startup, use lazy loading
|
| 842 |
+
# Create a minimal runtime for UI initialization (without model loading)
|
| 843 |
+
if SPACES_AVAILABLE:
|
| 844 |
+
print(">>> Hugging Face Spaces detected. Using Zero GPU lazy loading.")
|
| 845 |
+
print(">>> Model will be loaded on first GPU request.")
|
| 846 |
+
|
| 847 |
+
# Create a placeholder runtime with minimal initialization for UI
|
| 848 |
+
class PlaceholderRuntime:
|
| 849 |
+
def __init__(self):
|
| 850 |
+
self.fbx_available = False
|
| 851 |
+
self.prompt_engineering_host = args.prompt_engineering_host
|
| 852 |
+
|
| 853 |
+
def rewrite_text_and_infer_time(self, text: str):
|
| 854 |
+
# For prompt rewriting, we don't need GPU
|
| 855 |
+
from hymotion.prompt_engineering.prompt_rewrite import PromptRewriter
|
| 856 |
+
rewriter = PromptRewriter(host=self.prompt_engineering_host)
|
| 857 |
+
return rewriter.rewrite_prompt_and_infer_time(text)
|
| 858 |
+
|
| 859 |
+
runtime = PlaceholderRuntime()
|
| 860 |
+
else:
|
| 861 |
+
# Local development: load model immediately
|
| 862 |
+
print(">>> Local environment detected. Loading model at startup.")
|
| 863 |
+
skip_model_loading = False
|
| 864 |
+
if not os.path.exists(ckpt):
|
| 865 |
+
print(f">>> [WARNING] Checkpoint file not found: {ckpt}")
|
| 866 |
+
print(f">>> [WARNING] Model loading will be skipped. Motion generation will not be available.")
|
| 867 |
+
skip_model_loading = True
|
| 868 |
+
|
| 869 |
+
print(">>> Initializing T2MRuntime...")
|
| 870 |
+
if "USE_HF_MODELS" not in os.environ:
|
| 871 |
+
os.environ["USE_HF_MODELS"] = "1"
|
| 872 |
+
|
| 873 |
+
skip_text = False
|
| 874 |
+
runtime = T2MRuntime(
|
| 875 |
+
config_path=cfg,
|
| 876 |
+
ckpt_name=ckpt,
|
| 877 |
+
skip_text=skip_text,
|
| 878 |
+
device_ids=None,
|
| 879 |
+
prompt_engineering_host=args.prompt_engineering_host,
|
| 880 |
+
skip_model_loading=skip_model_loading,
|
| 881 |
+
)
|
| 882 |
+
_global_runtime = runtime # Set global runtime for GPU function
|
| 883 |
+
|
| 884 |
+
ui = T2MGradioUI(runtime=runtime, args=args)
|
| 885 |
+
demo = ui.build_ui()
|
| 886 |
+
return demo
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
# Create demo at module level for Hugging Face Spaces
|
| 890 |
+
final_model_path = try_to_download_model()
|
| 891 |
+
demo = create_demo(final_model_path)
|
| 892 |
+
|
| 893 |
+
if __name__ == "__main__":
|
| 894 |
+
demo.launch()
|
hymotion/network/attention.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import flash_attn
|
| 10 |
+
from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func
|
| 11 |
+
except ImportError:
|
| 12 |
+
flash_attn = None
|
| 13 |
+
flash_attn_varlen_func = None
|
| 14 |
+
_flash_attn_forward = None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
MEMORY_LAYOUT = {
|
| 18 |
+
"flash": (lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), lambda x: x),
|
| 19 |
+
"torch": (lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2)),
|
| 20 |
+
"vanilla": (lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2)),
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def attention(
|
| 25 |
+
q: Tensor,
|
| 26 |
+
k: Tensor,
|
| 27 |
+
v: Tensor,
|
| 28 |
+
mode: str = "flash",
|
| 29 |
+
drop_rate: float = 0.0,
|
| 30 |
+
attn_mask: Optional[Tensor] = None,
|
| 31 |
+
causal: bool = False,
|
| 32 |
+
cu_seqlens_q: Optional[Tensor] = None,
|
| 33 |
+
cu_seqlens_kv: Optional[Tensor] = None,
|
| 34 |
+
max_seqlen_q: Optional[int] = None,
|
| 35 |
+
max_seqlen_kv: Optional[int] = None,
|
| 36 |
+
batch_size: int = 1,
|
| 37 |
+
training: bool = True,
|
| 38 |
+
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
| 39 |
+
"""Perform QKV self attention.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
q (Tensor): Query tensor with shape [b, s, h, d], where h is the number of heads.
|
| 43 |
+
k (Tensor): Key tensor with shape [b, s1, h, d]
|
| 44 |
+
v (Tensor): Value tensor with shape [b, s1, h, d]
|
| 45 |
+
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
| 46 |
+
drop_rate (float): Dropout rate in attention map. (default: 0)
|
| 47 |
+
attn_mask (Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, h, s, s1] (torch or vanilla).
|
| 48 |
+
(default: None)
|
| 49 |
+
causal (bool): Whether to use causal attention. (default: False)
|
| 50 |
+
cu_seqlens_q (Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
| 51 |
+
used to index into q.
|
| 52 |
+
cu_seqlens_kv (Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
| 53 |
+
used to index into kv.
|
| 54 |
+
max_seqlen_q (int): The maximum sequence length in the batch of q.
|
| 55 |
+
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Tensor: Output tensor after self attention with shape [b, s, hd]
|
| 59 |
+
"""
|
| 60 |
+
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
| 61 |
+
q = pre_attn_layout(q)
|
| 62 |
+
k = pre_attn_layout(k)
|
| 63 |
+
v = pre_attn_layout(v)
|
| 64 |
+
|
| 65 |
+
if mode == "torch":
|
| 66 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
| 67 |
+
attn_mask = attn_mask.to(q.dtype)
|
| 68 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
| 69 |
+
elif mode == "flash":
|
| 70 |
+
assert flash_attn_varlen_func is not None, "flash_attn is not installed or not supported"
|
| 71 |
+
x = flash_attn_varlen_func(
|
| 72 |
+
q,
|
| 73 |
+
k,
|
| 74 |
+
v,
|
| 75 |
+
cu_seqlens_q,
|
| 76 |
+
cu_seqlens_kv,
|
| 77 |
+
max_seqlen_q,
|
| 78 |
+
max_seqlen_kv,
|
| 79 |
+
)
|
| 80 |
+
# x with shape [(bxs), a, d]
|
| 81 |
+
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
|
| 82 |
+
elif mode == "vanilla":
|
| 83 |
+
scale_factor = 1.0 / math.sqrt(q.size(-1))
|
| 84 |
+
b, a, s_q, _ = q.shape
|
| 85 |
+
s_k = k.size(2)
|
| 86 |
+
attn_bias = torch.zeros(b, a, s_q, s_k, dtype=q.dtype, device=q.device)
|
| 87 |
+
if causal:
|
| 88 |
+
# Only applied to self attention
|
| 89 |
+
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
|
| 90 |
+
temp_mask = torch.ones(b, a, s_q, s_q, dtype=torch.bool, device=q.device).tril(diagonal=0)
|
| 91 |
+
attn_bias.masked_fill_(~temp_mask, float("-inf"))
|
| 92 |
+
attn_bias = attn_bias.to(q.dtype)
|
| 93 |
+
if attn_mask is not None:
|
| 94 |
+
if attn_mask.dtype == torch.bool:
|
| 95 |
+
attn_bias.masked_fill_(~attn_mask, float("-inf"))
|
| 96 |
+
else:
|
| 97 |
+
attn_bias = attn_bias + attn_mask
|
| 98 |
+
|
| 99 |
+
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
| 100 |
+
attn = attn + attn_bias
|
| 101 |
+
attn = attn.softmax(dim=-1)
|
| 102 |
+
attn = torch.dropout(attn, p=drop_rate, train=training)
|
| 103 |
+
x = attn @ v
|
| 104 |
+
else:
|
| 105 |
+
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
| 106 |
+
|
| 107 |
+
x = post_attn_layout(x)
|
| 108 |
+
b, s, h, d = x.shape
|
| 109 |
+
out = x.reshape(b, s, -1)
|
| 110 |
+
return out
|
hymotion/network/bricks.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_activation_layer(act_type: str) -> Callable[[], nn.Module]:
|
| 9 |
+
if act_type == "gelu":
|
| 10 |
+
return lambda: nn.GELU()
|
| 11 |
+
elif act_type == "gelu_tanh":
|
| 12 |
+
return lambda: nn.GELU(approximate="tanh")
|
| 13 |
+
elif act_type == "relu":
|
| 14 |
+
return nn.ReLU
|
| 15 |
+
elif act_type == "silu":
|
| 16 |
+
return nn.SiLU
|
| 17 |
+
else:
|
| 18 |
+
raise ValueError(f"Unknown activation type: {act_type}")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_norm_layer(norm_type: Optional[str]):
|
| 22 |
+
if norm_type == "layer":
|
| 23 |
+
return nn.LayerNorm
|
| 24 |
+
elif norm_type == "rms":
|
| 25 |
+
return RMSNorm
|
| 26 |
+
elif norm_type == "none" or norm_type is None:
|
| 27 |
+
return nn.Identity
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError(f"Unknown norm type: {norm_type}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class RMSNorm(nn.Module):
|
| 33 |
+
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6) -> None:
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.eps = eps
|
| 36 |
+
if elementwise_affine:
|
| 37 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 38 |
+
|
| 39 |
+
def _norm(self, x: Tensor) -> Tensor:
|
| 40 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 41 |
+
|
| 42 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 43 |
+
output = self._norm(x.float()).type_as(x)
|
| 44 |
+
if hasattr(self, "weight"):
|
| 45 |
+
output = output * self.weight
|
| 46 |
+
return output
|
hymotion/network/encoders.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
from ..utils.misc import to_2tuple
|
| 9 |
+
from .bricks import get_activation_layer, get_norm_layer
|
| 10 |
+
from .modulate_layers import ModulateDiT, modulate
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MLP(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
in_dim: int,
|
| 17 |
+
feat_dim: int,
|
| 18 |
+
out_dim: Optional[int] = None,
|
| 19 |
+
act_type: str = "gelu",
|
| 20 |
+
norm_type: Optional[str] = None,
|
| 21 |
+
bias: bool = True,
|
| 22 |
+
drop: float = 0.0,
|
| 23 |
+
use_conv: bool = False,
|
| 24 |
+
) -> None:
|
| 25 |
+
super().__init__()
|
| 26 |
+
out_dim = out_dim or in_dim
|
| 27 |
+
feat_dim = feat_dim or in_dim
|
| 28 |
+
bias = to_2tuple(bias)
|
| 29 |
+
drop_probs = to_2tuple(drop)
|
| 30 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
| 31 |
+
|
| 32 |
+
self.fc1 = linear_layer(in_dim, feat_dim, bias=bias[0] if isinstance(bias, (list, tuple)) else bias)
|
| 33 |
+
self.act = get_activation_layer(act_type)()
|
| 34 |
+
self.drop1 = nn.Dropout(drop_probs[0] if isinstance(drop_probs, (list, tuple)) else drop_probs)
|
| 35 |
+
self.norm = get_norm_layer(norm_type)(feat_dim) if norm_type else nn.Identity()
|
| 36 |
+
self.fc2 = linear_layer(feat_dim, out_dim, bias=bias[1] if isinstance(bias, (list, tuple)) else bias)
|
| 37 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 38 |
+
|
| 39 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 40 |
+
x = self.fc1(x)
|
| 41 |
+
x = self.act(x)
|
| 42 |
+
x = self.drop1(x)
|
| 43 |
+
x = self.norm(x)
|
| 44 |
+
x = self.fc2(x)
|
| 45 |
+
x = self.drop2(x)
|
| 46 |
+
return x
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class MLPEncoder(nn.Module):
|
| 50 |
+
def __init__(self, in_dim: int, feat_dim: int, num_layers: int, act_type: str = "silu") -> None:
|
| 51 |
+
super(MLPEncoder, self).__init__()
|
| 52 |
+
self.in_dim = in_dim
|
| 53 |
+
self.feat_dim = feat_dim
|
| 54 |
+
linears = []
|
| 55 |
+
linears.append(nn.Linear(in_features=in_dim, out_features=self.feat_dim))
|
| 56 |
+
for i in range(num_layers - 1):
|
| 57 |
+
linears.append(get_activation_layer(act_type)())
|
| 58 |
+
linears.append(nn.Linear(self.feat_dim, self.feat_dim))
|
| 59 |
+
self.linears = nn.Sequential(*linears)
|
| 60 |
+
|
| 61 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 62 |
+
return self.linears(x)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class FinalLayer(nn.Module):
|
| 66 |
+
def __init__(self, feat_dim: int, out_dim: int, act_type: str = "gelu", zero_init=False, **kwargs):
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.norm_final = nn.LayerNorm(feat_dim, elementwise_affine=False, eps=1e-6)
|
| 70 |
+
self.adaLN_modulation = ModulateDiT(feat_dim, factor=2, act_type=act_type)
|
| 71 |
+
self.linear = nn.Linear(feat_dim, out_dim, bias=True)
|
| 72 |
+
if zero_init:
|
| 73 |
+
nn.init.zeros_(self.linear.weight)
|
| 74 |
+
nn.init.zeros_(self.linear.bias)
|
| 75 |
+
|
| 76 |
+
def forward(self, x: Tensor, adapter: Tensor) -> Tensor:
|
| 77 |
+
shift, scale = self.adaLN_modulation(adapter).chunk(2, dim=-1)
|
| 78 |
+
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
| 79 |
+
x = self.linear(x)
|
| 80 |
+
return x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class TimestepEmbeddingEncoder(nn.Module):
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
embedding_dim: int,
|
| 87 |
+
feat_dim: int,
|
| 88 |
+
act_type: str = "silu",
|
| 89 |
+
time_factor: float = 1.0,
|
| 90 |
+
) -> None:
|
| 91 |
+
super(TimestepEmbeddingEncoder, self).__init__()
|
| 92 |
+
|
| 93 |
+
self.embedding_dim = embedding_dim
|
| 94 |
+
self.feat_dim = feat_dim
|
| 95 |
+
self.time_factor = time_factor
|
| 96 |
+
blocks = [
|
| 97 |
+
nn.Linear(embedding_dim, self.feat_dim),
|
| 98 |
+
get_activation_layer(act_type)(),
|
| 99 |
+
nn.Linear(self.feat_dim, self.feat_dim),
|
| 100 |
+
]
|
| 101 |
+
self.blocks = nn.Sequential(*blocks)
|
| 102 |
+
|
| 103 |
+
def forward(self, t: Tensor) -> Tensor:
|
| 104 |
+
x = self.blocks(self.sinusodial_embedding(t, self.embedding_dim, time_factor=self.time_factor)).unsqueeze(1)
|
| 105 |
+
return x
|
| 106 |
+
|
| 107 |
+
@staticmethod
|
| 108 |
+
def sinusodial_embedding(
|
| 109 |
+
timesteps: Tensor, embedding_dim: int, temperature: float = 10000.0, time_factor: float = 1.0
|
| 110 |
+
) -> Tensor:
|
| 111 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
| 112 |
+
timesteps = timesteps * time_factor
|
| 113 |
+
half = embedding_dim // 2
|
| 114 |
+
freqs = torch.exp(
|
| 115 |
+
-torch.log(torch.tensor(temperature)) * torch.arange(start=0, end=half, dtype=torch.float) / half
|
| 116 |
+
).to(device=timesteps.device)
|
| 117 |
+
args = timesteps[:, None].float() * freqs[None]
|
| 118 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 119 |
+
if embedding_dim % 2:
|
| 120 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 121 |
+
return embedding
|
hymotion/network/hymotion_mmdit.py
ADDED
|
@@ -0,0 +1,636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
from ..utils.loaders import load_object
|
| 12 |
+
from ..utils.type_converter import get_module_device
|
| 13 |
+
from .attention import attention
|
| 14 |
+
from .bricks import get_activation_layer, get_norm_layer
|
| 15 |
+
from .encoders import MLP, MLPEncoder, TimestepEmbeddingEncoder
|
| 16 |
+
from .modulate_layers import ModulateDiT, apply_gate, modulate
|
| 17 |
+
from .positional_encoding import RotaryEmbedding
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MMBaseBlock(nn.Module):
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
feat_dim: int,
|
| 24 |
+
num_heads: int,
|
| 25 |
+
mlp_ratio: float,
|
| 26 |
+
dropout: float,
|
| 27 |
+
positional_encoding_cfg: dict,
|
| 28 |
+
apply_rope_to_single_branch: bool,
|
| 29 |
+
):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.feat_dim = feat_dim
|
| 32 |
+
self.num_heads = num_heads
|
| 33 |
+
self.mlp_ratio = mlp_ratio
|
| 34 |
+
self.dropout = dropout
|
| 35 |
+
|
| 36 |
+
assert self.feat_dim % num_heads == 0, f"feat_dim {self.feat_dim} must be divisible by num_heads {num_heads}"
|
| 37 |
+
self.head_dim = self.feat_dim // num_heads
|
| 38 |
+
|
| 39 |
+
self.mlp_hidden_dim = int(self.feat_dim * mlp_ratio)
|
| 40 |
+
|
| 41 |
+
self._positional_encoding_cfg = positional_encoding_cfg.copy()
|
| 42 |
+
self.rotary_emb = RotaryEmbedding(num_feats=self.head_dim, **self._positional_encoding_cfg)
|
| 43 |
+
self.apply_rope_to_single_branch = apply_rope_to_single_branch
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class MMDoubleStreamBlock(MMBaseBlock):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
feat_dim: int,
|
| 50 |
+
num_heads: int,
|
| 51 |
+
mlp_ratio: float,
|
| 52 |
+
dropout: float,
|
| 53 |
+
mlp_act_type: str,
|
| 54 |
+
qk_norm_type: Optional[str] = None,
|
| 55 |
+
qkv_bias: bool = False,
|
| 56 |
+
positional_encoding_cfg: dict = {
|
| 57 |
+
"max_seq_len": 5000,
|
| 58 |
+
"use_real": True,
|
| 59 |
+
},
|
| 60 |
+
apply_rope_to_single_branch: bool = True,
|
| 61 |
+
):
|
| 62 |
+
super().__init__(feat_dim, num_heads, mlp_ratio, dropout, positional_encoding_cfg, apply_rope_to_single_branch)
|
| 63 |
+
|
| 64 |
+
self.motion_mod = ModulateDiT(
|
| 65 |
+
self.feat_dim,
|
| 66 |
+
factor=6,
|
| 67 |
+
act_type="silu",
|
| 68 |
+
)
|
| 69 |
+
self.motion_norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
|
| 70 |
+
|
| 71 |
+
motion_qkv_out_dim = self.feat_dim * 3
|
| 72 |
+
self.motion_qkv = nn.Linear(self.feat_dim, motion_qkv_out_dim, bias=qkv_bias)
|
| 73 |
+
|
| 74 |
+
self.motion_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
|
| 75 |
+
self.motion_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
|
| 76 |
+
self.motion_out_proj = nn.Linear(self.feat_dim, self.feat_dim, bias=qkv_bias)
|
| 77 |
+
self.motion_norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
|
| 78 |
+
self.motion_mlp = MLP(
|
| 79 |
+
self.feat_dim,
|
| 80 |
+
self.mlp_hidden_dim,
|
| 81 |
+
act_type=mlp_act_type,
|
| 82 |
+
bias=True,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
self.text_mod = ModulateDiT(
|
| 86 |
+
self.feat_dim,
|
| 87 |
+
factor=6,
|
| 88 |
+
act_type="silu",
|
| 89 |
+
)
|
| 90 |
+
self.text_norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
|
| 91 |
+
|
| 92 |
+
text_qkv_out_dim = self.feat_dim * 3
|
| 93 |
+
self.text_qkv = nn.Linear(self.feat_dim, text_qkv_out_dim, bias=qkv_bias)
|
| 94 |
+
|
| 95 |
+
self.text_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
|
| 96 |
+
self.text_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
|
| 97 |
+
self.text_out_proj = nn.Linear(self.feat_dim, self.feat_dim, bias=qkv_bias)
|
| 98 |
+
self.text_norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
|
| 99 |
+
self.text_mlp = MLP(
|
| 100 |
+
self.feat_dim,
|
| 101 |
+
self.mlp_hidden_dim,
|
| 102 |
+
act_type=mlp_act_type,
|
| 103 |
+
bias=True,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def forward(
|
| 107 |
+
self,
|
| 108 |
+
motion_feat: Tensor,
|
| 109 |
+
text_feat: Tensor,
|
| 110 |
+
adapter: Tensor,
|
| 111 |
+
attn_mask: Optional[Tensor] = None,
|
| 112 |
+
) -> Tuple[Tensor, Tensor]:
|
| 113 |
+
(
|
| 114 |
+
motion_shift_msa,
|
| 115 |
+
motion_scale_msa,
|
| 116 |
+
motion_gate_msa,
|
| 117 |
+
motion_shift_mlp,
|
| 118 |
+
motion_scale_mlp,
|
| 119 |
+
motion_gate_mlp,
|
| 120 |
+
) = self.motion_mod(adapter).chunk(6, dim=-1)
|
| 121 |
+
(
|
| 122 |
+
text_shift_msa,
|
| 123 |
+
text_scale_msa,
|
| 124 |
+
text_gate_msa,
|
| 125 |
+
text_shift_mlp,
|
| 126 |
+
text_scale_mlp,
|
| 127 |
+
text_gate_mlp,
|
| 128 |
+
) = self.text_mod(
|
| 129 |
+
adapter
|
| 130 |
+
).chunk(6, dim=-1)
|
| 131 |
+
|
| 132 |
+
motion_modulated = self.motion_norm1(motion_feat)
|
| 133 |
+
motion_modulated = modulate(motion_modulated, shift=motion_shift_msa, scale=motion_scale_msa)
|
| 134 |
+
motion_qkv = self.motion_qkv(motion_modulated)
|
| 135 |
+
|
| 136 |
+
motion_q, motion_k, motion_v = rearrange(motion_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
| 137 |
+
motion_q = self.motion_q_norm(motion_q).to(motion_v)
|
| 138 |
+
motion_k = self.motion_k_norm(motion_k).to(motion_v)
|
| 139 |
+
|
| 140 |
+
if self.apply_rope_to_single_branch:
|
| 141 |
+
# NOTE: we don't apply RoPE to text_branch_two here
|
| 142 |
+
motion_q, motion_k = self.rotary_emb.apply_rotary_emb(motion_q, motion_k)
|
| 143 |
+
|
| 144 |
+
text_modulated = self.text_norm1(text_feat)
|
| 145 |
+
text_modulated = modulate(text_modulated, shift=text_shift_msa, scale=text_scale_msa)
|
| 146 |
+
text_qkv = self.text_qkv(text_modulated)
|
| 147 |
+
|
| 148 |
+
text_q, text_k, text_v = rearrange(
|
| 149 |
+
text_qkv,
|
| 150 |
+
"B L (K H D) -> K B L H D",
|
| 151 |
+
K=3,
|
| 152 |
+
H=self.num_heads,
|
| 153 |
+
)
|
| 154 |
+
text_q = self.text_q_norm(text_q).to(text_v)
|
| 155 |
+
text_k = self.text_k_norm(text_k).to(text_v)
|
| 156 |
+
|
| 157 |
+
q = torch.cat((motion_q, text_q), dim=1)
|
| 158 |
+
k = torch.cat((motion_k, text_k), dim=1)
|
| 159 |
+
v = torch.cat((motion_v, text_v), dim=1)
|
| 160 |
+
|
| 161 |
+
if not self.apply_rope_to_single_branch:
|
| 162 |
+
q, k = self.rotary_emb.apply_rotary_emb(q, k)
|
| 163 |
+
|
| 164 |
+
bsz, total_len, _, _ = q.shape
|
| 165 |
+
motion_len = motion_feat.shape[1]
|
| 166 |
+
text_len = text_feat.shape[1]
|
| 167 |
+
dropout_p = 0.0 if not self.training else self.dropout
|
| 168 |
+
|
| 169 |
+
attn_output = attention(
|
| 170 |
+
q,
|
| 171 |
+
k,
|
| 172 |
+
v,
|
| 173 |
+
mode="torch",
|
| 174 |
+
drop_rate=dropout_p,
|
| 175 |
+
attn_mask=attn_mask,
|
| 176 |
+
causal=False,
|
| 177 |
+
cu_seqlens_q=None,
|
| 178 |
+
cu_seqlens_kv=None,
|
| 179 |
+
max_seqlen_q=None,
|
| 180 |
+
max_seqlen_kv=None,
|
| 181 |
+
batch_size=bsz,
|
| 182 |
+
training=self.training,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
motion_attn_output, text_attn_output = (
|
| 186 |
+
attn_output[:, :motion_len, ...],
|
| 187 |
+
attn_output[:, motion_len:, ...],
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
motion_feat = motion_feat + apply_gate(self.motion_out_proj(motion_attn_output), gate=motion_gate_msa)
|
| 191 |
+
motion_feat = motion_feat + apply_gate(
|
| 192 |
+
self.motion_mlp(
|
| 193 |
+
modulate(
|
| 194 |
+
self.motion_norm2(motion_feat),
|
| 195 |
+
shift=motion_shift_mlp,
|
| 196 |
+
scale=motion_scale_mlp,
|
| 197 |
+
)
|
| 198 |
+
),
|
| 199 |
+
gate=motion_gate_mlp,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
text_feat = text_feat + apply_gate(self.text_out_proj(text_attn_output), gate=text_gate_msa)
|
| 203 |
+
text_feat = text_feat + apply_gate(
|
| 204 |
+
self.text_mlp(
|
| 205 |
+
modulate(
|
| 206 |
+
self.text_norm2(text_feat),
|
| 207 |
+
shift=text_shift_mlp,
|
| 208 |
+
scale=text_scale_mlp,
|
| 209 |
+
)
|
| 210 |
+
),
|
| 211 |
+
gate=text_gate_mlp,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
return motion_feat, text_feat
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class MMSingleStreamBlock(MMBaseBlock):
|
| 218 |
+
def __init__(
|
| 219 |
+
self,
|
| 220 |
+
feat_dim: int,
|
| 221 |
+
num_heads: int,
|
| 222 |
+
mlp_ratio: float,
|
| 223 |
+
dropout: float,
|
| 224 |
+
mlp_act_type: str,
|
| 225 |
+
qk_norm_type: Optional[str] = None,
|
| 226 |
+
qkv_bias: bool = False,
|
| 227 |
+
positional_encoding_cfg: dict = {
|
| 228 |
+
"max_seq_len": 5000,
|
| 229 |
+
"use_real": True,
|
| 230 |
+
},
|
| 231 |
+
apply_rope_to_single_branch: bool = True,
|
| 232 |
+
):
|
| 233 |
+
super().__init__(feat_dim, num_heads, mlp_ratio, dropout, positional_encoding_cfg, apply_rope_to_single_branch)
|
| 234 |
+
|
| 235 |
+
self.modulation = ModulateDiT(self.feat_dim, factor=3, act_type="silu")
|
| 236 |
+
self.norm = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
|
| 237 |
+
|
| 238 |
+
# qkv and mlp_in
|
| 239 |
+
qkv_factor = 3
|
| 240 |
+
self.linear1 = nn.Linear(self.feat_dim, self.feat_dim * qkv_factor + self.mlp_hidden_dim, bias=qkv_bias)
|
| 241 |
+
# proj and mlp_out
|
| 242 |
+
self.linear2 = nn.Linear(self.feat_dim + self.mlp_hidden_dim, self.feat_dim, bias=qkv_bias)
|
| 243 |
+
|
| 244 |
+
self.q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
|
| 245 |
+
self.k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
|
| 246 |
+
|
| 247 |
+
self.mlp_act = get_activation_layer(mlp_act_type)()
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self,
|
| 251 |
+
x: Tensor,
|
| 252 |
+
split_len: int,
|
| 253 |
+
adapter: Tensor,
|
| 254 |
+
attn_mask: Optional[Tensor] = None,
|
| 255 |
+
) -> Tensor:
|
| 256 |
+
(
|
| 257 |
+
shift_msa,
|
| 258 |
+
scale_msa,
|
| 259 |
+
gate_msa,
|
| 260 |
+
) = self.modulation(
|
| 261 |
+
adapter
|
| 262 |
+
).chunk(3, dim=-1)
|
| 263 |
+
x_modulated = modulate(self.norm(x), shift_msa, scale_msa)
|
| 264 |
+
|
| 265 |
+
qkv, mlp_hidden = torch.split(self.linear1(x_modulated), [3 * self.feat_dim, self.mlp_hidden_dim], dim=-1)
|
| 266 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
| 267 |
+
|
| 268 |
+
q = self.q_norm(q).to(v)
|
| 269 |
+
k = self.k_norm(k).to(v)
|
| 270 |
+
|
| 271 |
+
q1, q2 = q[:, :split_len, ...], q[:, split_len:, ...]
|
| 272 |
+
k1, k2 = k[:, :split_len, ...], k[:, split_len:, ...]
|
| 273 |
+
# apply rotary position embedding
|
| 274 |
+
if self.apply_rope_to_single_branch:
|
| 275 |
+
q1, k1 = self.rotary_emb.apply_rotary_emb(q1, k1)
|
| 276 |
+
q = torch.cat((q1, q2), dim=1)
|
| 277 |
+
k = torch.cat((k1, k2), dim=1)
|
| 278 |
+
if not self.apply_rope_to_single_branch:
|
| 279 |
+
q, k = self.rotary_emb.apply_rotary_emb(q, k)
|
| 280 |
+
|
| 281 |
+
bsz, total_len = x_modulated.shape[:2]
|
| 282 |
+
dropout_p = 0.0 if not self.training else self.dropout
|
| 283 |
+
|
| 284 |
+
attn_output = attention(
|
| 285 |
+
q,
|
| 286 |
+
k,
|
| 287 |
+
v,
|
| 288 |
+
mode="torch",
|
| 289 |
+
drop_rate=dropout_p,
|
| 290 |
+
attn_mask=attn_mask,
|
| 291 |
+
causal=False,
|
| 292 |
+
cu_seqlens_q=None,
|
| 293 |
+
cu_seqlens_kv=None,
|
| 294 |
+
max_seqlen_q=None,
|
| 295 |
+
max_seqlen_kv=None,
|
| 296 |
+
batch_size=bsz,
|
| 297 |
+
training=self.training,
|
| 298 |
+
)
|
| 299 |
+
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp_hidden)), 2))
|
| 300 |
+
|
| 301 |
+
return x + apply_gate(output, gate=gate_msa)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class HunyuanMotionMMDiT(nn.Module):
|
| 305 |
+
def __init__(
|
| 306 |
+
self,
|
| 307 |
+
input_dim: int,
|
| 308 |
+
feat_dim: int,
|
| 309 |
+
output_dim: Optional[int] = None,
|
| 310 |
+
ctxt_input_dim: int = 4096,
|
| 311 |
+
vtxt_input_dim: int = 256,
|
| 312 |
+
text_refiner_module: str = "hymotion/network/token_refiner.SingleTokenRefiner",
|
| 313 |
+
text_refiner_cfg: dict = {
|
| 314 |
+
"num_layers": 2,
|
| 315 |
+
},
|
| 316 |
+
num_layers: int = 12,
|
| 317 |
+
num_heads: int = 16,
|
| 318 |
+
mlp_ratio: float = 4.0,
|
| 319 |
+
mlp_act_type: str = "gelu_tanh",
|
| 320 |
+
norm_type: str = "layer",
|
| 321 |
+
qk_norm_type: str = "rms",
|
| 322 |
+
qkv_bias: bool = True,
|
| 323 |
+
dropout: float = 0.0,
|
| 324 |
+
final_layer_module: str = "hymotion/network/encoders.FinalLayer",
|
| 325 |
+
final_layer_cfg: dict = {
|
| 326 |
+
"act_type": "silu",
|
| 327 |
+
},
|
| 328 |
+
mask_mode: Optional[str] = None,
|
| 329 |
+
apply_rope_to_single_branch: bool = True,
|
| 330 |
+
insert_start_token: bool = False,
|
| 331 |
+
with_long_skip_connection: bool = False,
|
| 332 |
+
time_factor: float = 1.0,
|
| 333 |
+
narrowband_length: float = 2.0,
|
| 334 |
+
**kwargs,
|
| 335 |
+
):
|
| 336 |
+
super().__init__()
|
| 337 |
+
self.motion_input_dim = input_dim
|
| 338 |
+
self.ctxt_input_dim = ctxt_input_dim
|
| 339 |
+
self.vtxt_input_dim = vtxt_input_dim
|
| 340 |
+
self.feat_dim = feat_dim
|
| 341 |
+
self.output_dim = output_dim or input_dim
|
| 342 |
+
self.mask_mode = mask_mode
|
| 343 |
+
self.insert_start_token = insert_start_token
|
| 344 |
+
self.time_factor = time_factor
|
| 345 |
+
self.narrowband_length = narrowband_length * 30.0
|
| 346 |
+
if self.insert_start_token:
|
| 347 |
+
self.start_token = nn.Parameter(torch.randn(1, feat_dim))
|
| 348 |
+
self.with_long_skip_connection = with_long_skip_connection
|
| 349 |
+
if self.with_long_skip_connection:
|
| 350 |
+
from .encoders import FinalLayer
|
| 351 |
+
|
| 352 |
+
self.long_skip_net = FinalLayer(feat_dim=feat_dim, out_dim=feat_dim, act_type="silu")
|
| 353 |
+
|
| 354 |
+
self.input_encoder = nn.Linear(in_features=input_dim, out_features=feat_dim)
|
| 355 |
+
self.ctxt_encoder = nn.Linear(in_features=ctxt_input_dim, out_features=feat_dim)
|
| 356 |
+
self.vtxt_encoder = MLPEncoder(in_dim=vtxt_input_dim, feat_dim=feat_dim, num_layers=2, act_type="silu")
|
| 357 |
+
self.timestep_encoder = TimestepEmbeddingEncoder(
|
| 358 |
+
embedding_dim=feat_dim,
|
| 359 |
+
feat_dim=feat_dim,
|
| 360 |
+
time_factor=time_factor,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
if text_refiner_module != "" and text_refiner_module is not None:
|
| 364 |
+
text_refiner_cfg.update(input_dim=feat_dim, feat_dim=feat_dim, num_heads=num_heads)
|
| 365 |
+
self._text_refiner_cfg = text_refiner_cfg.copy()
|
| 366 |
+
self.text_refiner = load_object(text_refiner_module, text_refiner_cfg)
|
| 367 |
+
|
| 368 |
+
self.num_layers = num_layers
|
| 369 |
+
assert num_layers % 3 == 0, f"num_layers must be divisible by 3, but got {num_layers}"
|
| 370 |
+
self.mm_double_blocks_layers = int(num_layers // 3)
|
| 371 |
+
self.mm_single_blocks_layers = int(num_layers - num_layers // 3)
|
| 372 |
+
|
| 373 |
+
self.double_blocks = nn.ModuleList(
|
| 374 |
+
[
|
| 375 |
+
MMDoubleStreamBlock(
|
| 376 |
+
feat_dim=feat_dim,
|
| 377 |
+
num_heads=num_heads,
|
| 378 |
+
mlp_ratio=mlp_ratio,
|
| 379 |
+
dropout=dropout,
|
| 380 |
+
mlp_act_type=mlp_act_type,
|
| 381 |
+
qk_norm_type=qk_norm_type,
|
| 382 |
+
qkv_bias=qkv_bias,
|
| 383 |
+
apply_rope_to_single_branch=apply_rope_to_single_branch,
|
| 384 |
+
)
|
| 385 |
+
for _ in range(self.mm_double_blocks_layers)
|
| 386 |
+
]
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
self.single_blocks = nn.ModuleList(
|
| 390 |
+
[
|
| 391 |
+
MMSingleStreamBlock(
|
| 392 |
+
feat_dim=feat_dim,
|
| 393 |
+
num_heads=num_heads,
|
| 394 |
+
mlp_ratio=mlp_ratio,
|
| 395 |
+
dropout=dropout,
|
| 396 |
+
mlp_act_type=mlp_act_type,
|
| 397 |
+
qk_norm_type=qk_norm_type,
|
| 398 |
+
qkv_bias=qkv_bias,
|
| 399 |
+
apply_rope_to_single_branch=apply_rope_to_single_branch,
|
| 400 |
+
)
|
| 401 |
+
for _ in range(self.mm_single_blocks_layers)
|
| 402 |
+
]
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
final_layer_cfg.update(feat_dim=feat_dim, out_dim=self.output_dim)
|
| 406 |
+
self._final_layer_cfg = final_layer_cfg.copy()
|
| 407 |
+
self.final_layer = load_object(final_layer_module, final_layer_cfg)
|
| 408 |
+
|
| 409 |
+
def forward(
|
| 410 |
+
self,
|
| 411 |
+
x: Tensor,
|
| 412 |
+
ctxt_input: Tensor,
|
| 413 |
+
vtxt_input: Tensor,
|
| 414 |
+
timesteps: Tensor,
|
| 415 |
+
x_mask_temporal: Tensor,
|
| 416 |
+
ctxt_mask_temporal: Tensor,
|
| 417 |
+
**kwargs,
|
| 418 |
+
) -> Tensor:
|
| 419 |
+
device = get_module_device(self)
|
| 420 |
+
|
| 421 |
+
motion_feat = self.input_encoder(x)
|
| 422 |
+
if self.with_long_skip_connection:
|
| 423 |
+
origin_feat = motion_feat
|
| 424 |
+
if self.insert_start_token:
|
| 425 |
+
# (B, 1, D) + (B, L, D) -> (B, L+1, D)
|
| 426 |
+
start_token = self.start_token[None].repeat(motion_feat.shape[0], 1, 1)
|
| 427 |
+
motion_feat = torch.cat((start_token, motion_feat), dim=1)
|
| 428 |
+
x_mask_temporal = torch.cat(
|
| 429 |
+
[
|
| 430 |
+
torch.ones_like(x_mask_temporal[:, :1], dtype=torch.bool),
|
| 431 |
+
x_mask_temporal,
|
| 432 |
+
],
|
| 433 |
+
dim=1,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
timestep_feat = self.timestep_encoder(timesteps)
|
| 437 |
+
vtxt_feat = self.vtxt_encoder(vtxt_input.float())
|
| 438 |
+
adapter = timestep_feat + vtxt_feat
|
| 439 |
+
|
| 440 |
+
motion_key_padding_mask = self._canonical_mask(x_mask_temporal).to(device)
|
| 441 |
+
ctxt_key_padding_mask = self._canonical_mask(ctxt_mask_temporal).to(device)
|
| 442 |
+
seq_key_padding_mask = torch.cat((motion_key_padding_mask, ctxt_key_padding_mask), dim=1)
|
| 443 |
+
if self.mask_mode is None:
|
| 444 |
+
seq_mask = None
|
| 445 |
+
elif self.mask_mode == "causal":
|
| 446 |
+
motion_len = motion_feat.shape[1]
|
| 447 |
+
seq_mask = torch.triu(
|
| 448 |
+
torch.full((motion_len, motion_len), float("-inf"), device=device),
|
| 449 |
+
diagonal=1,
|
| 450 |
+
)
|
| 451 |
+
elif self.mask_mode == "narrowband":
|
| 452 |
+
window = int(round(self.narrowband_length))
|
| 453 |
+
motion_len = motion_feat.shape[1]
|
| 454 |
+
idx = torch.arange(motion_len, device=device)
|
| 455 |
+
dist = (idx[None, :] - idx[:, None]).abs()
|
| 456 |
+
band = dist <= window
|
| 457 |
+
seq_mask = torch.full((motion_len, motion_len), float("-inf"), device=device)
|
| 458 |
+
seq_mask = seq_mask.masked_fill(band, 0.0)
|
| 459 |
+
else:
|
| 460 |
+
raise ValueError(f"Unsupported mask mode: {self.mask_mode}")
|
| 461 |
+
|
| 462 |
+
ctxt_feat = self.ctxt_encoder(ctxt_input.float())
|
| 463 |
+
if hasattr(self, "text_refiner"):
|
| 464 |
+
ctxt_feat = self.text_refiner(x=ctxt_feat, t=timesteps, mask=(ctxt_key_padding_mask == 0).to(device))
|
| 465 |
+
|
| 466 |
+
# precompute shared attention masks (broadcastable over heads)
|
| 467 |
+
bsz = x.shape[0]
|
| 468 |
+
motion_len = motion_feat.shape[1]
|
| 469 |
+
text_len = ctxt_feat.shape[1]
|
| 470 |
+
total_len = motion_len + text_len
|
| 471 |
+
mask_dtype = motion_feat.dtype
|
| 472 |
+
attn_mask_double = self._build_dmm_attn_mask_shared(
|
| 473 |
+
bsz=bsz,
|
| 474 |
+
motion_len=motion_len,
|
| 475 |
+
text_len=text_len,
|
| 476 |
+
dtype=mask_dtype,
|
| 477 |
+
key_padding_mask=seq_key_padding_mask,
|
| 478 |
+
attn_mask=seq_mask,
|
| 479 |
+
device=device,
|
| 480 |
+
)
|
| 481 |
+
for i_layer, mod in enumerate(self.double_blocks):
|
| 482 |
+
motion_feat, ctxt_feat = mod(
|
| 483 |
+
motion_feat=motion_feat,
|
| 484 |
+
text_feat=ctxt_feat,
|
| 485 |
+
adapter=adapter,
|
| 486 |
+
attn_mask=attn_mask_double,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# precompute shared attention masks for single stream blocks too
|
| 490 |
+
split_len = motion_feat.shape[1]
|
| 491 |
+
x = torch.cat((motion_feat, ctxt_feat), 1)
|
| 492 |
+
attn_mask_single = self._build_smm_attn_mask_shared(
|
| 493 |
+
bsz=bsz,
|
| 494 |
+
split_len=split_len,
|
| 495 |
+
total_len=total_len,
|
| 496 |
+
dtype=mask_dtype,
|
| 497 |
+
key_padding_mask=seq_key_padding_mask,
|
| 498 |
+
attn_mask=seq_mask,
|
| 499 |
+
device=device,
|
| 500 |
+
)
|
| 501 |
+
for i_layer, mod in enumerate(self.single_blocks):
|
| 502 |
+
x = mod(
|
| 503 |
+
x=x,
|
| 504 |
+
split_len=split_len,
|
| 505 |
+
adapter=adapter,
|
| 506 |
+
attn_mask=attn_mask_single,
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
x = x[:, :split_len, ...]
|
| 510 |
+
if self.insert_start_token:
|
| 511 |
+
x = x[:, 1:, ...]
|
| 512 |
+
|
| 513 |
+
if self.with_long_skip_connection:
|
| 514 |
+
# long skip only consider timestep_feat
|
| 515 |
+
x = self.long_skip_net(origin_feat, timestep_feat) + x
|
| 516 |
+
|
| 517 |
+
predicted_res = self.final_layer(x, adapter)
|
| 518 |
+
return predicted_res
|
| 519 |
+
|
| 520 |
+
@staticmethod
|
| 521 |
+
def _canonical_mask(input_mask: Tensor) -> Tensor:
|
| 522 |
+
if input_mask.ndim == 1:
|
| 523 |
+
input_mask = input_mask.unsqueeze(1)
|
| 524 |
+
key_padding_mask = torch.where(
|
| 525 |
+
input_mask,
|
| 526 |
+
torch.zeros_like(input_mask, dtype=torch.float),
|
| 527 |
+
torch.full_like(input_mask, float("-inf"), dtype=torch.float),
|
| 528 |
+
)
|
| 529 |
+
return key_padding_mask
|
| 530 |
+
|
| 531 |
+
def _build_dmm_attn_mask_shared(
|
| 532 |
+
self,
|
| 533 |
+
bsz: int,
|
| 534 |
+
motion_len: int,
|
| 535 |
+
text_len: int,
|
| 536 |
+
dtype: torch.dtype,
|
| 537 |
+
key_padding_mask: Optional[Tensor],
|
| 538 |
+
attn_mask: Optional[Tensor],
|
| 539 |
+
device: torch.device,
|
| 540 |
+
) -> Tensor:
|
| 541 |
+
"""
|
| 542 |
+
NOTE:
|
| 543 |
+
motion_k text_k
|
| 544 |
+
motion_q [M→M] [M→T]
|
| 545 |
+
text_q [T→M] [T→T]
|
| 546 |
+
only [M→M] contains given mask
|
| 547 |
+
"""
|
| 548 |
+
total_len = motion_len + text_len
|
| 549 |
+
base = torch.zeros((bsz, 1, total_len, total_len), dtype=dtype, device=device)
|
| 550 |
+
if attn_mask is not None:
|
| 551 |
+
if attn_mask.dim() != 2 or attn_mask.shape != (motion_len, motion_len):
|
| 552 |
+
raise RuntimeError(
|
| 553 |
+
f"attn_mask should be 2D with shape {(motion_len, motion_len)}, got {attn_mask.shape}"
|
| 554 |
+
)
|
| 555 |
+
base[:, :, :motion_len, :motion_len] += attn_mask.view(1, 1, motion_len, motion_len)
|
| 556 |
+
if key_padding_mask is not None:
|
| 557 |
+
mask_total_len = key_padding_mask.shape[1]
|
| 558 |
+
if mask_total_len == motion_len:
|
| 559 |
+
pad = torch.zeros((bsz, text_len), dtype=key_padding_mask.dtype, device=device)
|
| 560 |
+
key_padding_mask = torch.cat((key_padding_mask, pad), dim=-1)
|
| 561 |
+
base = base + key_padding_mask.view(bsz, 1, 1, total_len)
|
| 562 |
+
# disable T→M
|
| 563 |
+
base[:, :, motion_len:, :motion_len] = float("-inf")
|
| 564 |
+
return base
|
| 565 |
+
|
| 566 |
+
def _build_smm_attn_mask_shared(
|
| 567 |
+
self,
|
| 568 |
+
bsz: int,
|
| 569 |
+
split_len: int,
|
| 570 |
+
total_len: int,
|
| 571 |
+
dtype: torch.dtype,
|
| 572 |
+
key_padding_mask: Optional[Tensor],
|
| 573 |
+
attn_mask: Optional[Tensor],
|
| 574 |
+
device: torch.device,
|
| 575 |
+
) -> Tensor:
|
| 576 |
+
"""
|
| 577 |
+
NOTE:
|
| 578 |
+
motion_k text_k
|
| 579 |
+
motion_q [M→M] [M→T]
|
| 580 |
+
text_q [T→M] [T→T]
|
| 581 |
+
only [M→M] contains given mask
|
| 582 |
+
"""
|
| 583 |
+
base = torch.zeros((bsz, 1, total_len, total_len), dtype=dtype, device=device)
|
| 584 |
+
if attn_mask is not None:
|
| 585 |
+
if attn_mask.dim() != 2 or attn_mask.shape != (split_len, split_len):
|
| 586 |
+
raise RuntimeError(f"attn_mask should be 2D with shape {(split_len, split_len)}, got {attn_mask.shape}")
|
| 587 |
+
base[:, :, :split_len, :split_len] += attn_mask.view(1, 1, split_len, split_len)
|
| 588 |
+
if key_padding_mask is not None:
|
| 589 |
+
mask_total_len = key_padding_mask.shape[1]
|
| 590 |
+
if mask_total_len == split_len:
|
| 591 |
+
pad = torch.zeros(
|
| 592 |
+
(bsz, total_len - split_len),
|
| 593 |
+
dtype=key_padding_mask.dtype,
|
| 594 |
+
device=device,
|
| 595 |
+
)
|
| 596 |
+
key_padding_mask = torch.cat((key_padding_mask, pad), dim=-1)
|
| 597 |
+
base = base + key_padding_mask.view(bsz, 1, 1, total_len)
|
| 598 |
+
# disable T→M
|
| 599 |
+
base[:, :, split_len:, :split_len] = float("-inf")
|
| 600 |
+
return base
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
if __name__ == "__main__":
|
| 604 |
+
# python -m hymotion.network.hymotion_mmdit
|
| 605 |
+
|
| 606 |
+
from configs._base_.model_network_base import MOTION_MODEL_CONFIG # pyright: ignore
|
| 607 |
+
|
| 608 |
+
network_module_cfg = MOTION_MODEL_CONFIG["1.04B_narrowband"]["network_module_args"]
|
| 609 |
+
network_module_cfg = dict(network_module_cfg) # convert to normal dict
|
| 610 |
+
|
| 611 |
+
bsz, seq_len, text_seq_len, input_dim = 1, 360, 128, 201
|
| 612 |
+
network_module_cfg["input_dim"] = input_dim
|
| 613 |
+
MMDiT = HunyuanMotionMMDiT(**network_module_cfg)
|
| 614 |
+
|
| 615 |
+
x = torch.randn(bsz, seq_len, input_dim)
|
| 616 |
+
ctxt_condition = torch.randn(bsz, text_seq_len, 4096)
|
| 617 |
+
vtxt_condition = torch.randn(bsz, 1, 768)
|
| 618 |
+
timesteps = torch.randint(0, 1000, (bsz,))
|
| 619 |
+
length = torch.arange(seq_len).unsqueeze(0).repeat(bsz, 1)
|
| 620 |
+
ctxt_length = torch.arange(text_seq_len).unsqueeze(0).repeat(bsz, 1)
|
| 621 |
+
x_mask_temporal = length < 100
|
| 622 |
+
ctxt_mask_temporal = ctxt_length < 50
|
| 623 |
+
x = MMDiT(
|
| 624 |
+
x=x,
|
| 625 |
+
ctxt_input=ctxt_condition,
|
| 626 |
+
vtxt_input=vtxt_condition,
|
| 627 |
+
timesteps=timesteps,
|
| 628 |
+
x_mask_temporal=x_mask_temporal,
|
| 629 |
+
ctxt_mask_temporal=ctxt_mask_temporal,
|
| 630 |
+
)
|
| 631 |
+
assert x.shape == (
|
| 632 |
+
bsz,
|
| 633 |
+
seq_len,
|
| 634 |
+
input_dim,
|
| 635 |
+
), f"unexpected output shape: {x.shape}, which should be ({bsz}, {seq_len}, {input_dim})"
|
| 636 |
+
print(x.shape)
|
hymotion/network/modulate_layers.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
|
| 6 |
+
from .bricks import get_activation_layer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ModulateDiT(nn.Module):
|
| 10 |
+
def __init__(self, feat_dim: int, factor: int, act_type: str = "silu"):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.act = get_activation_layer(act_type)()
|
| 13 |
+
self.linear = nn.Linear(feat_dim, factor * feat_dim, bias=True)
|
| 14 |
+
nn.init.zeros_(self.linear.weight)
|
| 15 |
+
nn.init.zeros_(self.linear.bias)
|
| 16 |
+
|
| 17 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 18 |
+
return self.linear(self.act(x))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def modulate(x: Tensor, shift: Optional[Tensor] = None, scale: Optional[Tensor] = None) -> Tensor:
|
| 22 |
+
if shift is not None and scale is not None:
|
| 23 |
+
assert len(x.shape) == len(shift.shape) == len(scale.shape), (
|
| 24 |
+
"x, shift, scale must have the same number of dimensions, "
|
| 25 |
+
f"but got x.shape: {x.shape}, "
|
| 26 |
+
f"shift.shape: {shift.shape} "
|
| 27 |
+
f"and scale.shape: {scale.shape}"
|
| 28 |
+
)
|
| 29 |
+
if shift is not None and scale is not None:
|
| 30 |
+
return x * (1 + scale) + shift
|
| 31 |
+
elif shift is not None:
|
| 32 |
+
return x + shift
|
| 33 |
+
elif scale is not None:
|
| 34 |
+
return x * (1 + scale)
|
| 35 |
+
else:
|
| 36 |
+
return x
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def apply_gate(x: Tensor, gate: Optional[Tensor] = None, tanh: bool = False) -> Tensor:
|
| 40 |
+
if gate is not None:
|
| 41 |
+
assert len(x.shape) == len(
|
| 42 |
+
gate.shape
|
| 43 |
+
), f"x, gate must have the same number of dimensions, but got {x.shape} and {gate.shape}"
|
| 44 |
+
if gate is None:
|
| 45 |
+
return x
|
| 46 |
+
if tanh:
|
| 47 |
+
return x * gate.tanh()
|
| 48 |
+
else:
|
| 49 |
+
return x * gate
|
hymotion/network/positional_encoding.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RotaryEmbedding(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
num_feats: int,
|
| 13 |
+
max_seq_len: Union[Tensor, int],
|
| 14 |
+
temperature: int = 10000,
|
| 15 |
+
use_real: bool = False,
|
| 16 |
+
theta_rescale_factor: float = 1.0,
|
| 17 |
+
interpolation_factor: float = 1.0,
|
| 18 |
+
) -> None:
|
| 19 |
+
super(RotaryEmbedding, self).__init__()
|
| 20 |
+
assert num_feats % 2 == 0, "num_feats (head_dim) must be even for RoPE."
|
| 21 |
+
self.num_feats = num_feats
|
| 22 |
+
self.max_seq_len = max_seq_len
|
| 23 |
+
self.temperature = temperature
|
| 24 |
+
self.use_real = use_real
|
| 25 |
+
self.theta_rescale_factor = theta_rescale_factor
|
| 26 |
+
self.interpolation_factor = interpolation_factor
|
| 27 |
+
|
| 28 |
+
if isinstance(max_seq_len, int):
|
| 29 |
+
max_seq_len = torch.arange(max_seq_len).float()
|
| 30 |
+
|
| 31 |
+
if theta_rescale_factor != 1.0:
|
| 32 |
+
temperature *= theta_rescale_factor ** (self.num_feats / (self.num_feats - 2))
|
| 33 |
+
dim_t = torch.arange(0, self.num_feats, 2, dtype=torch.float32)
|
| 34 |
+
freqs = 1.0 / (temperature ** (2 * torch.div(dim_t, 2, rounding_mode="trunc") / self.num_feats)) # [D/2]
|
| 35 |
+
freqs = torch.outer(max_seq_len.float() * interpolation_factor, freqs) # [S, D/2]
|
| 36 |
+
if use_real:
|
| 37 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
| 38 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
| 39 |
+
self.freqs_cis = (freqs_cos, freqs_sin)
|
| 40 |
+
else:
|
| 41 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2]
|
| 42 |
+
self.freqs_cis = freqs_cis
|
| 43 |
+
|
| 44 |
+
def reshape_for_broadcast(
|
| 45 |
+
self, freqs_cis: Union[Tensor, Tuple[Tensor, Tensor]], x: Tensor
|
| 46 |
+
) -> Union[Tuple[Tensor, Tensor], Tensor]:
|
| 47 |
+
ndim = x.ndim
|
| 48 |
+
assert 0 <= 1 < ndim
|
| 49 |
+
|
| 50 |
+
if isinstance(freqs_cis, tuple):
|
| 51 |
+
# freqs_cis: (cos, sin) in real space
|
| 52 |
+
assert (
|
| 53 |
+
freqs_cis[0].shape[-1] == x.shape[-1]
|
| 54 |
+
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape} on the head_dim dimension"
|
| 55 |
+
assert freqs_cis[0].shape[0] >= x.shape[1], (
|
| 56 |
+
f"freqs_cis shape {freqs_cis[0].shape} should be larger than or equal to "
|
| 57 |
+
f"x shape {x.shape} on the time dimension"
|
| 58 |
+
)
|
| 59 |
+
shape = []
|
| 60 |
+
for i, d in enumerate(x.shape):
|
| 61 |
+
if i == 1:
|
| 62 |
+
shape.append(-1)
|
| 63 |
+
elif i == ndim - 1:
|
| 64 |
+
shape.append(d)
|
| 65 |
+
else:
|
| 66 |
+
shape.append(1)
|
| 67 |
+
return (
|
| 68 |
+
freqs_cis[0].view(*shape)[:, : x.shape[1], ...],
|
| 69 |
+
freqs_cis[1].view(*shape)[:, : x.shape[1], ...],
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
# freqs_cis: values in complex space
|
| 73 |
+
assert (
|
| 74 |
+
freqs_cis.shape[-1] == x.shape[-1]
|
| 75 |
+
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape} on the head_dim dimension"
|
| 76 |
+
assert freqs_cis.shape[0] >= x.shape[1], (
|
| 77 |
+
f"freqs_cis shape {freqs_cis.shape} should be larger than or equal to "
|
| 78 |
+
f"x shape {x.shape} on the time dimension"
|
| 79 |
+
)
|
| 80 |
+
shape = []
|
| 81 |
+
for i, d in enumerate(x.shape):
|
| 82 |
+
if i == 1:
|
| 83 |
+
shape.append(-1)
|
| 84 |
+
elif i == ndim - 1:
|
| 85 |
+
shape.append(d)
|
| 86 |
+
else:
|
| 87 |
+
shape.append(1)
|
| 88 |
+
return freqs_cis.view(*shape)[:, : x.shape[1], ...]
|
| 89 |
+
|
| 90 |
+
def rotate_half(self, x: Tensor) -> Tensor:
|
| 91 |
+
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
| 92 |
+
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 93 |
+
|
| 94 |
+
def apply_rotary_emb(self, xq: Tensor, xk: Tensor) -> Tuple[Tensor, Tensor]:
|
| 95 |
+
xk_out = None
|
| 96 |
+
if isinstance(self.freqs_cis, tuple):
|
| 97 |
+
cos, sin = self.reshape_for_broadcast(self.freqs_cis, xq) # [B, L, H, D]
|
| 98 |
+
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
| 99 |
+
# real * cos - imag * sin
|
| 100 |
+
# imag * cos + real * sin
|
| 101 |
+
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
|
| 102 |
+
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
|
| 103 |
+
else:
|
| 104 |
+
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
| 105 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
| 106 |
+
freqs_cis = self.reshape_for_broadcast(self.freqs_cis, xq_)
|
| 107 |
+
# Handle device transfer based on return type
|
| 108 |
+
if isinstance(freqs_cis, tuple):
|
| 109 |
+
freqs_cis = (freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device))
|
| 110 |
+
else:
|
| 111 |
+
freqs_cis = freqs_cis.to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
| 112 |
+
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
| 113 |
+
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
| 114 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
| 115 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
| 116 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
| 117 |
+
|
| 118 |
+
return xq_out, xk_out
|
| 119 |
+
|
| 120 |
+
def __repr__(self) -> str:
|
| 121 |
+
repr_str = self.__class__.__name__
|
| 122 |
+
repr_str += f"(num_feats={self.num_feats}, "
|
| 123 |
+
repr_str += f"max_seq_len={self.max_seq_len}, "
|
| 124 |
+
repr_str += f"temperature={self.temperature}, "
|
| 125 |
+
repr_str += f"use_real={self.use_real}, "
|
| 126 |
+
repr_str += f"theta_rescale_factor={self.theta_rescale_factor}, "
|
| 127 |
+
repr_str += f"interpolation_factor={self.interpolation_factor})"
|
| 128 |
+
return repr_str
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class PositionalEncoding(nn.Module):
|
| 132 |
+
def __init__(self, num_feats: int, dropout: float = 0.1, max_len: int = 5000):
|
| 133 |
+
super(PositionalEncoding, self).__init__()
|
| 134 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 135 |
+
|
| 136 |
+
pe = torch.zeros(max_len, num_feats)
|
| 137 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 138 |
+
div_term = torch.exp(torch.arange(0, num_feats, 2).float() * (-np.log(10000.0) / num_feats))
|
| 139 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 140 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 141 |
+
pe = pe.unsqueeze(0) # shape of [1, L, D]
|
| 142 |
+
self.register_buffer("pe", pe)
|
| 143 |
+
|
| 144 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 145 |
+
x = x + self.pe[:, : x.shape[1], :] # shape of [B, L, D]
|
| 146 |
+
return self.dropout(x)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
# python -m hymotion.network.positional_encoding
|
| 151 |
+
num_feats = 32
|
| 152 |
+
rope = RotaryEmbedding(num_feats=num_feats, max_seq_len=5000, use_real=True)
|
| 153 |
+
x = torch.ones(1, 360, 1, num_feats)
|
| 154 |
+
text = torch.ones(1, 256, 1, num_feats)
|
| 155 |
+
q1, k1 = x.clone(), x.clone()
|
| 156 |
+
q2, k2 = text.clone(), text.clone()
|
| 157 |
+
print(x.shape)
|
| 158 |
+
# q1, k1 = rope.apply_rotary_emb(q1, k1)
|
| 159 |
+
# q2, k2 = rope.apply_rotary_emb(q2, k2)
|
| 160 |
+
q = torch.cat([q1, q2], dim=1)
|
| 161 |
+
k = torch.cat([k1, k2], dim=1)
|
| 162 |
+
q, k = rope.apply_rotary_emb(q, k)
|
| 163 |
+
q, k = q[0, :, 0, :], k[0, :, 0, :]
|
| 164 |
+
attn = (q[:, None] * k[None, :]).sum(dim=-1)
|
| 165 |
+
# softmax
|
| 166 |
+
# attn = torch.softmax(attn, dim=-1)
|
| 167 |
+
attn = attn.cpu().numpy()
|
| 168 |
+
|
| 169 |
+
import matplotlib.pyplot as plt
|
| 170 |
+
|
| 171 |
+
plt.imshow(attn, cmap="hot")
|
| 172 |
+
plt.colorbar()
|
| 173 |
+
plt.savefig("attn.png")
|
| 174 |
+
breakpoint()
|
hymotion/network/text_encoders/model_constants.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__all__ = [
|
| 2 |
+
"PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION",
|
| 3 |
+
]
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION = """
|
| 7 |
+
Summarize human motion only from the user text for representation: action categories, key body-part movements, order/transitions, trajectory/direction, posture; include style/emotion/speed only if present. Explicitly capture laterality (left/right) when mentioned; do not guess. If multiple actions are described, indicate the count of distinct actions (e.g., actions=3) and their order. Do not invent missing info. Keep one concise paragraph.
|
| 8 |
+
"""
|
hymotion/network/text_encoders/text_encoder.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForCausalLM,
|
| 9 |
+
AutoTokenizer,
|
| 10 |
+
CLIPTextModel,
|
| 11 |
+
CLIPTokenizer,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from ...utils.type_converter import get_module_device
|
| 15 |
+
from .model_constants import PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION
|
| 16 |
+
|
| 17 |
+
USE_HF_MODELS = os.environ.get("USE_HF_MODELS", "0") == "1"
|
| 18 |
+
|
| 19 |
+
if USE_HF_MODELS:
|
| 20 |
+
QWEN_PATH = "Qwen/Qwen3-8B"
|
| 21 |
+
CLIP_PATH = "openai/clip-vit-large-patch14"
|
| 22 |
+
else:
|
| 23 |
+
QWEN_PATH = "ckpts/Qwen3-8B"
|
| 24 |
+
CLIP_PATH = "ckpts/clip-vit-large-patch14"
|
| 25 |
+
|
| 26 |
+
LLM_ENCODER_LAYOUT = {
|
| 27 |
+
"qwen3": {
|
| 28 |
+
"module_path": QWEN_PATH,
|
| 29 |
+
"template": [
|
| 30 |
+
{"role": "system", "content": f"{PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION}"},
|
| 31 |
+
{"role": "user", "content": "{}"},
|
| 32 |
+
],
|
| 33 |
+
"crop_start": 0,
|
| 34 |
+
"tokenizer_class": AutoTokenizer,
|
| 35 |
+
"text_encoder_class": AutoModelForCausalLM,
|
| 36 |
+
},
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
SENTENCE_EMB_LAYOUT = {
|
| 40 |
+
"clipl": {
|
| 41 |
+
"module_path": CLIP_PATH,
|
| 42 |
+
"tokenizer_class": CLIPTokenizer,
|
| 43 |
+
"text_encoder_class": CLIPTextModel,
|
| 44 |
+
"pooling_mode": "pooler_output",
|
| 45 |
+
"max_length": 77,
|
| 46 |
+
},
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class HYTextModel(nn.Module):
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
llm_type: Optional[str] = "qwen3",
|
| 54 |
+
max_length_llm: int = 512,
|
| 55 |
+
sentence_emb_type: Optional[str] = "clipl",
|
| 56 |
+
max_length_sentence_emb: int = 77,
|
| 57 |
+
enable_llm_padding: bool = True,
|
| 58 |
+
) -> None:
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.text_encoder_type = "hy_text_model"
|
| 61 |
+
|
| 62 |
+
self.sentence_emb_type = sentence_emb_type
|
| 63 |
+
self.sentence_emb_text_encoder = None
|
| 64 |
+
self.sentence_emb_tokenizer = None
|
| 65 |
+
self.vtxt_dim = 0
|
| 66 |
+
if sentence_emb_type is not None:
|
| 67 |
+
assert sentence_emb_type in SENTENCE_EMB_LAYOUT, f"Unsupported sentence embedding type: {sentence_emb_type}"
|
| 68 |
+
self.max_length_sentence_emb = max_length_sentence_emb or SENTENCE_EMB_LAYOUT[sentence_emb_type].get(
|
| 69 |
+
"max_length", 77
|
| 70 |
+
)
|
| 71 |
+
self._sentence_emb_pooling_mode = SENTENCE_EMB_LAYOUT[sentence_emb_type].get(
|
| 72 |
+
"pooling_mode", "pooler_output"
|
| 73 |
+
)
|
| 74 |
+
tokenizer_kwargs = SENTENCE_EMB_LAYOUT[sentence_emb_type].get("tokenizer_kwargs", {})
|
| 75 |
+
|
| 76 |
+
self.sentence_emb_tokenizer = SENTENCE_EMB_LAYOUT[sentence_emb_type]["tokenizer_class"].from_pretrained(
|
| 77 |
+
SENTENCE_EMB_LAYOUT[sentence_emb_type]["module_path"],
|
| 78 |
+
max_length=self.max_length_sentence_emb,
|
| 79 |
+
**tokenizer_kwargs,
|
| 80 |
+
)
|
| 81 |
+
self.sentence_emb_text_encoder = SENTENCE_EMB_LAYOUT[sentence_emb_type][
|
| 82 |
+
"text_encoder_class"
|
| 83 |
+
].from_pretrained(SENTENCE_EMB_LAYOUT[sentence_emb_type]["module_path"])
|
| 84 |
+
self.sentence_emb_text_encoder = self.sentence_emb_text_encoder.eval().requires_grad_(False)
|
| 85 |
+
self.vtxt_dim = self.sentence_emb_text_encoder.config.hidden_size
|
| 86 |
+
|
| 87 |
+
self.llm_type = llm_type
|
| 88 |
+
self.llm_text_encoder = None
|
| 89 |
+
self.llm_tokenizer = None
|
| 90 |
+
self.ctxt_dim = 0
|
| 91 |
+
self.crop_start = 0
|
| 92 |
+
self.max_length_llm = max_length_llm
|
| 93 |
+
if llm_type is not None:
|
| 94 |
+
assert llm_type in LLM_ENCODER_LAYOUT, f"Unsupported LLM type: {llm_type}"
|
| 95 |
+
self._orig_max_length_llm = max_length_llm
|
| 96 |
+
self.enable_llm_padding = enable_llm_padding
|
| 97 |
+
self.llm_tokenizer = LLM_ENCODER_LAYOUT[llm_type]["tokenizer_class"].from_pretrained(
|
| 98 |
+
LLM_ENCODER_LAYOUT[llm_type]["module_path"],
|
| 99 |
+
padding_side="right",
|
| 100 |
+
)
|
| 101 |
+
self.llm_text_encoder = LLM_ENCODER_LAYOUT[llm_type]["text_encoder_class"].from_pretrained(
|
| 102 |
+
LLM_ENCODER_LAYOUT[llm_type]["module_path"], low_cpu_mem_usage=True
|
| 103 |
+
)
|
| 104 |
+
self.llm_text_encoder = self.llm_text_encoder.eval().requires_grad_(False)
|
| 105 |
+
self.ctxt_dim = self.llm_text_encoder.config.hidden_size
|
| 106 |
+
|
| 107 |
+
self.crop_start = self._compute_crop_start()
|
| 108 |
+
self.max_length_llm = self._orig_max_length_llm + self.crop_start
|
| 109 |
+
|
| 110 |
+
@torch.no_grad()
|
| 111 |
+
def encode_llm(self, text: List[str]) -> Tuple[Tensor, Tensor]:
|
| 112 |
+
if self.llm_type is None or self.llm_text_encoder is None or self.llm_tokenizer is None:
|
| 113 |
+
raise ValueError("LLM model not initialized")
|
| 114 |
+
|
| 115 |
+
device = get_module_device(self)
|
| 116 |
+
llm_text = [
|
| 117 |
+
(
|
| 118 |
+
self.llm_tokenizer.apply_chat_template(
|
| 119 |
+
self.apply_text_to_template(one_text, LLM_ENCODER_LAYOUT[self.llm_type]["template"]),
|
| 120 |
+
tokenize=False,
|
| 121 |
+
add_generation_prompt=False,
|
| 122 |
+
enable_thinking=False,
|
| 123 |
+
)
|
| 124 |
+
if self.llm_type == "qwen3"
|
| 125 |
+
else self.apply_text_to_template(one_text, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
|
| 126 |
+
)
|
| 127 |
+
for one_text in text
|
| 128 |
+
]
|
| 129 |
+
padding_mode = "max_length" if self.enable_llm_padding else False
|
| 130 |
+
llm_batch_encoding = self.llm_tokenizer(
|
| 131 |
+
llm_text,
|
| 132 |
+
return_length=False,
|
| 133 |
+
return_overflowing_tokens=False,
|
| 134 |
+
truncation=True,
|
| 135 |
+
return_attention_mask=True,
|
| 136 |
+
max_length=self.max_length_llm, # = crop_start + _orig_max_length_llm
|
| 137 |
+
padding=padding_mode,
|
| 138 |
+
return_tensors="pt",
|
| 139 |
+
)
|
| 140 |
+
llm_outputs = (
|
| 141 |
+
self.llm_text_encoder(
|
| 142 |
+
input_ids=llm_batch_encoding["input_ids"].to(device),
|
| 143 |
+
attention_mask=llm_batch_encoding["attention_mask"].to(device),
|
| 144 |
+
output_hidden_states=True,
|
| 145 |
+
)
|
| 146 |
+
if self.llm_type == "qwen3"
|
| 147 |
+
else self.llm_text_encoder(
|
| 148 |
+
input_ids=llm_batch_encoding["input_ids"].to(device),
|
| 149 |
+
attention_mask=llm_batch_encoding["attention_mask"].to(device),
|
| 150 |
+
)
|
| 151 |
+
)
|
| 152 |
+
if self.llm_type == "qwen3":
|
| 153 |
+
ctxt_raw = llm_outputs.hidden_states[-1]
|
| 154 |
+
else:
|
| 155 |
+
ctxt_raw = llm_outputs.last_hidden_state
|
| 156 |
+
|
| 157 |
+
start = self.crop_start
|
| 158 |
+
end = start + self._orig_max_length_llm
|
| 159 |
+
ctxt_raw = ctxt_raw[:, start:end].contiguous() # [bs, _orig_max_length_llm, hidden]
|
| 160 |
+
ctxt_length = (llm_batch_encoding["attention_mask"].sum(dim=-1).to(device) - start).clamp(
|
| 161 |
+
min=0, max=self._orig_max_length_llm
|
| 162 |
+
)
|
| 163 |
+
return ctxt_raw, ctxt_length
|
| 164 |
+
|
| 165 |
+
@torch.no_grad()
|
| 166 |
+
def encode_sentence_emb(self, text: List[str]) -> Tensor:
|
| 167 |
+
if (
|
| 168 |
+
self.sentence_emb_type is None
|
| 169 |
+
or self.sentence_emb_text_encoder is None
|
| 170 |
+
or self.sentence_emb_tokenizer is None
|
| 171 |
+
):
|
| 172 |
+
raise ValueError("Sentence embedding model not initialized")
|
| 173 |
+
|
| 174 |
+
device = get_module_device(self)
|
| 175 |
+
enc = self.sentence_emb_tokenizer(
|
| 176 |
+
text,
|
| 177 |
+
return_length=False,
|
| 178 |
+
return_overflowing_tokens=False,
|
| 179 |
+
truncation=True,
|
| 180 |
+
return_attention_mask=True,
|
| 181 |
+
max_length=self.max_length_sentence_emb,
|
| 182 |
+
padding=True,
|
| 183 |
+
return_tensors="pt",
|
| 184 |
+
)
|
| 185 |
+
out = self.sentence_emb_text_encoder(
|
| 186 |
+
input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device)
|
| 187 |
+
)
|
| 188 |
+
if self._sentence_emb_pooling_mode == "pooler_output":
|
| 189 |
+
# Pooler output pooling (clip-vit-large-patch14 等)
|
| 190 |
+
if hasattr(out, "pooler_output") and out.pooler_output is not None:
|
| 191 |
+
vtxt_raw = out.pooler_output.unsqueeze(1)
|
| 192 |
+
else:
|
| 193 |
+
vtxt_raw = self._encode_pooling(enc["attention_mask"].to(device), out.last_hidden_state)
|
| 194 |
+
elif self._sentence_emb_pooling_mode == "mean":
|
| 195 |
+
vtxt_raw = self._encode_pooling(enc["attention_mask"].to(device), out.last_hidden_state)
|
| 196 |
+
elif self._sentence_emb_pooling_mode == "last_token":
|
| 197 |
+
vtxt_raw = self._last_token_pool(out.last_hidden_state, enc["attention_mask"].to(device))
|
| 198 |
+
else:
|
| 199 |
+
raise ValueError(f"Unknown pooling mode: {self._sentence_emb_pooling_mode}")
|
| 200 |
+
|
| 201 |
+
return vtxt_raw
|
| 202 |
+
|
| 203 |
+
def encode(self, text: List[str]) -> Tuple[Tensor, Tensor, Tensor]:
|
| 204 |
+
ctxt_raw, ctxt_length = self.encode_llm(text=text)
|
| 205 |
+
vtxt_raw = self.encode_sentence_emb(text=text)
|
| 206 |
+
return vtxt_raw, ctxt_raw, ctxt_length
|
| 207 |
+
|
| 208 |
+
@staticmethod
|
| 209 |
+
def apply_text_to_template(text: str, template: Union[str, list]) -> Union[str, list]:
|
| 210 |
+
if isinstance(template, str):
|
| 211 |
+
return template.format(text)
|
| 212 |
+
elif isinstance(template, list):
|
| 213 |
+
return [
|
| 214 |
+
{"role": "system", "content": f"{template[0]['content']}"},
|
| 215 |
+
{"role": "user", "content": f"{text}"},
|
| 216 |
+
]
|
| 217 |
+
else:
|
| 218 |
+
raise TypeError(f"Unsupported template type: {type(template)}")
|
| 219 |
+
|
| 220 |
+
def _compute_crop_start(self) -> int:
|
| 221 |
+
if self.llm_type is None or self.llm_text_encoder is None or self.llm_tokenizer is None:
|
| 222 |
+
raise ValueError("LLM model not initialized")
|
| 223 |
+
|
| 224 |
+
def _find_subseq(a: str, b: str) -> int:
|
| 225 |
+
for i in range(0, len(a) - len(b) + 1):
|
| 226 |
+
if a[i : i + len(b)] == b:
|
| 227 |
+
return i
|
| 228 |
+
return -1
|
| 229 |
+
|
| 230 |
+
marker = "<BOC>"
|
| 231 |
+
if self.llm_type == "qwen3":
|
| 232 |
+
msgs = self.apply_text_to_template(marker, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
|
| 233 |
+
s = self.llm_tokenizer.apply_chat_template(
|
| 234 |
+
msgs, tokenize=False, add_generation_prompt=False, enable_thinking=False
|
| 235 |
+
)
|
| 236 |
+
else:
|
| 237 |
+
s = self.apply_text_to_template(marker, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
|
| 238 |
+
full_ids = self.llm_tokenizer(s, return_tensors="pt", add_special_tokens=True)["input_ids"][0].tolist()
|
| 239 |
+
marker_ids = self.llm_tokenizer(marker, return_tensors="pt", add_special_tokens=False)["input_ids"][0].tolist()
|
| 240 |
+
pos = _find_subseq(full_ids, marker_ids)
|
| 241 |
+
if pos >= 0:
|
| 242 |
+
return pos
|
| 243 |
+
else:
|
| 244 |
+
return max(0, len(full_ids) - 1)
|
| 245 |
+
|
| 246 |
+
def _pad_or_truncate_tensor(self, tensor: Tensor, target_length: int, dim: int = 0) -> Tensor:
|
| 247 |
+
current_length = tensor.shape[dim]
|
| 248 |
+
if current_length > target_length:
|
| 249 |
+
return tensor.narrow(dim, 0, target_length)
|
| 250 |
+
elif current_length < target_length:
|
| 251 |
+
pad_shape = list(tensor.shape)
|
| 252 |
+
pad_shape[dim] = target_length - current_length
|
| 253 |
+
padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor.narrow(dim, -1, 1)
|
| 254 |
+
return torch.cat([tensor, padding], dim=dim)
|
| 255 |
+
return tensor
|
| 256 |
+
|
| 257 |
+
def _encode_pooling(self, attention_mask: Tensor, token_embeddings: Tensor) -> Tensor:
|
| 258 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 259 |
+
sentence_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
|
| 260 |
+
input_mask_expanded.sum(1), min=1e-9
|
| 261 |
+
)
|
| 262 |
+
vtxt_raw = nn.functional.normalize(sentence_embeddings, p=2, dim=1).unsqueeze(1) # shape of [bs, 1, D]
|
| 263 |
+
return vtxt_raw
|
| 264 |
+
|
| 265 |
+
def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
| 266 |
+
left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
|
| 267 |
+
if left_padding:
|
| 268 |
+
vtxt_raw = last_hidden_states[:, -1]
|
| 269 |
+
else:
|
| 270 |
+
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 271 |
+
batch_size = last_hidden_states.shape[0]
|
| 272 |
+
vtxt_raw = last_hidden_states[
|
| 273 |
+
torch.arange(batch_size, device=last_hidden_states.device),
|
| 274 |
+
sequence_lengths,
|
| 275 |
+
]
|
| 276 |
+
vtxt_raw = nn.functional.normalize(vtxt_raw, p=2, dim=-1).unsqueeze(1) # shape of [bs, 1, D]
|
| 277 |
+
return vtxt_raw
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
# python -m hymotion.network.text_encoders.text_encoder
|
| 282 |
+
text_encoder = HYTextModel(llm_type="qwen3", max_length_llm=5)
|
| 283 |
+
vtxt_raw, ctxt_raw, ctxt_length = text_encoder.encode(["Hello, world!"])
|
| 284 |
+
print(vtxt_raw.shape, ctxt_raw.shape, ctxt_length)
|
| 285 |
+
|
| 286 |
+
crop_start = text_encoder._compute_crop_start()
|
| 287 |
+
print(f"crop_start: {crop_start} when using {text_encoder.llm_type}")
|
| 288 |
+
|
| 289 |
+
assert (
|
| 290 |
+
vtxt_raw.shape[1:] == (1, text_encoder.vtxt_dim)
|
| 291 |
+
and ctxt_raw.shape[1:] == (text_encoder._orig_max_length_llm, text_encoder.ctxt_dim)
|
| 292 |
+
and torch.all((ctxt_length >= 0) & (ctxt_length <= text_encoder._orig_max_length_llm))
|
| 293 |
+
), f"Got unexpected output shape: {vtxt_raw.shape}, {ctxt_raw.shape}, {ctxt_length}"
|
hymotion/network/token_refiner.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
from .attention import attention
|
| 9 |
+
from .bricks import get_norm_layer
|
| 10 |
+
from .encoders import MLP, MLPEncoder, TimestepEmbeddingEncoder
|
| 11 |
+
from .modulate_layers import ModulateDiT, apply_gate
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class IndividualTokenRefinerBlock(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
feat_dim: int,
|
| 18 |
+
num_heads: int,
|
| 19 |
+
mlp_ratio: float = 4.0,
|
| 20 |
+
dropout: float = 0.0,
|
| 21 |
+
mlp_act_type: str = "silu",
|
| 22 |
+
qk_norm_type: str = "layer",
|
| 23 |
+
qkv_bias: bool = True,
|
| 24 |
+
) -> None:
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.feat_dim = feat_dim
|
| 27 |
+
self.num_heads = num_heads
|
| 28 |
+
self.mlp_ratio = mlp_ratio
|
| 29 |
+
self.dropout = dropout
|
| 30 |
+
assert self.feat_dim % num_heads == 0, f"feat_dim {self.feat_dim} must be divisible by num_heads {num_heads}"
|
| 31 |
+
self.head_dim = feat_dim // num_heads
|
| 32 |
+
|
| 33 |
+
self.mlp_hidden_dim = int(feat_dim * mlp_ratio)
|
| 34 |
+
|
| 35 |
+
self.norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=True, eps=1e-6)
|
| 36 |
+
self.self_attn_qkv = nn.Linear(feat_dim, feat_dim * 3, bias=qkv_bias)
|
| 37 |
+
self.self_attn_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
|
| 38 |
+
self.self_attn_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
|
| 39 |
+
self.self_attn_proj = nn.Linear(feat_dim, feat_dim, bias=qkv_bias)
|
| 40 |
+
|
| 41 |
+
self.norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=True, eps=1e-6)
|
| 42 |
+
|
| 43 |
+
self.mlp = MLP(
|
| 44 |
+
in_dim=feat_dim,
|
| 45 |
+
feat_dim=self.mlp_hidden_dim,
|
| 46 |
+
act_type=mlp_act_type,
|
| 47 |
+
drop=dropout,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.adaLN_modulation = ModulateDiT(
|
| 51 |
+
feat_dim=feat_dim,
|
| 52 |
+
factor=2,
|
| 53 |
+
act_type="silu",
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: Tensor, c: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
|
| 57 |
+
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1)
|
| 58 |
+
norm_x = self.norm1(x)
|
| 59 |
+
qkv = self.self_attn_qkv(norm_x)
|
| 60 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
| 61 |
+
# Apply QK-Norm if needed
|
| 62 |
+
q = self.self_attn_q_norm(q).to(v)
|
| 63 |
+
k = self.self_attn_k_norm(k).to(v)
|
| 64 |
+
# Self-Attention
|
| 65 |
+
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
| 66 |
+
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
| 67 |
+
# FFN Layer
|
| 68 |
+
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
| 69 |
+
return x
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class IndividualTokenRefiner(nn.Module):
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
feat_dim: int,
|
| 76 |
+
num_heads: int,
|
| 77 |
+
num_layers: int,
|
| 78 |
+
mlp_ratio: float = 4.0,
|
| 79 |
+
dropout: float = 0.0,
|
| 80 |
+
mlp_act_type: str = "silu",
|
| 81 |
+
qk_norm_type: str = "layer",
|
| 82 |
+
qkv_bias: bool = True,
|
| 83 |
+
) -> None:
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.blocks = nn.ModuleList(
|
| 86 |
+
[
|
| 87 |
+
IndividualTokenRefinerBlock(
|
| 88 |
+
feat_dim=feat_dim,
|
| 89 |
+
num_heads=num_heads,
|
| 90 |
+
mlp_ratio=mlp_ratio,
|
| 91 |
+
dropout=dropout,
|
| 92 |
+
mlp_act_type=mlp_act_type,
|
| 93 |
+
qk_norm_type=qk_norm_type,
|
| 94 |
+
qkv_bias=qkv_bias,
|
| 95 |
+
)
|
| 96 |
+
for _ in range(num_layers)
|
| 97 |
+
]
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def forward(self, x: Tensor, c: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
| 101 |
+
self_attn_mask = None
|
| 102 |
+
if mask is not None:
|
| 103 |
+
batch_size = mask.shape[0]
|
| 104 |
+
seq_len = mask.shape[1]
|
| 105 |
+
mask = mask.to(x.device)
|
| 106 |
+
# batch_size x 1 x seq_len x seq_len
|
| 107 |
+
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
| 108 |
+
# batch_size x 1 x seq_len x seq_len
|
| 109 |
+
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
| 110 |
+
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
|
| 111 |
+
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
| 112 |
+
# avoids self-attention weight being NaN for padding tokens
|
| 113 |
+
# assume the shape of self_attn_mask is [B, H, Q, K] and this is self-attention (Q==K==L)
|
| 114 |
+
L = self_attn_mask.size(-1)
|
| 115 |
+
diag = torch.eye(L, dtype=torch.bool, device=self_attn_mask.device).view(1, 1, L, L) # [1,1,L,L]
|
| 116 |
+
# mark which query row is "all False" (no visible key)
|
| 117 |
+
all_false = ~self_attn_mask.any(dim=-1, keepdim=False) # [B, H, Q]
|
| 118 |
+
# expand to [B, H, Q, K], only for these rows, back to diagonal visible
|
| 119 |
+
all_false = all_false.unsqueeze(-1).expand(-1, -1, -1, L)
|
| 120 |
+
self_attn_mask = torch.where(all_false, diag.expand_as(self_attn_mask), self_attn_mask)
|
| 121 |
+
|
| 122 |
+
if self_attn_mask is not None:
|
| 123 |
+
self_attn_mask = torch.where(
|
| 124 |
+
self_attn_mask,
|
| 125 |
+
torch.zeros_like(self_attn_mask, dtype=torch.float),
|
| 126 |
+
torch.full_like(self_attn_mask, float("-inf"), dtype=torch.float),
|
| 127 |
+
)
|
| 128 |
+
for block in self.blocks:
|
| 129 |
+
x = block(x, c, self_attn_mask)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class SingleTokenRefiner(nn.Module):
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
input_dim: int,
|
| 137 |
+
feat_dim: int,
|
| 138 |
+
num_heads: int,
|
| 139 |
+
num_layers: int,
|
| 140 |
+
mlp_ratio: float = 4.0,
|
| 141 |
+
dropout: float = 0.0,
|
| 142 |
+
mlp_act_type: str = "silu",
|
| 143 |
+
qk_norm_type: str = "layer",
|
| 144 |
+
qkv_bias: bool = True,
|
| 145 |
+
attn_mode: str = "torch",
|
| 146 |
+
**kwargs,
|
| 147 |
+
) -> None:
|
| 148 |
+
super().__init__()
|
| 149 |
+
self.attn_mode = attn_mode
|
| 150 |
+
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
|
| 151 |
+
|
| 152 |
+
self.input_embedder = nn.Linear(input_dim, feat_dim, bias=True)
|
| 153 |
+
self.context_encoder = MLPEncoder(
|
| 154 |
+
in_dim=feat_dim,
|
| 155 |
+
feat_dim=feat_dim,
|
| 156 |
+
num_layers=2,
|
| 157 |
+
act_type=mlp_act_type,
|
| 158 |
+
)
|
| 159 |
+
self.timestep_encoder = TimestepEmbeddingEncoder(
|
| 160 |
+
embedding_dim=feat_dim,
|
| 161 |
+
feat_dim=feat_dim,
|
| 162 |
+
act_type=mlp_act_type,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self.individual_token_refiner = IndividualTokenRefiner(
|
| 166 |
+
feat_dim=feat_dim,
|
| 167 |
+
num_heads=num_heads,
|
| 168 |
+
num_layers=num_layers,
|
| 169 |
+
mlp_ratio=mlp_ratio,
|
| 170 |
+
dropout=dropout,
|
| 171 |
+
mlp_act_type=mlp_act_type,
|
| 172 |
+
qk_norm_type=qk_norm_type,
|
| 173 |
+
qkv_bias=qkv_bias,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def forward(self, x: Tensor, t: Tensor, mask: Optional[Tensor] = None) -> Tensor:
|
| 177 |
+
timestep_aware_representations = self.timestep_encoder(t)
|
| 178 |
+
|
| 179 |
+
if mask is None:
|
| 180 |
+
context_aware_representations = x.mean(dim=1)
|
| 181 |
+
else:
|
| 182 |
+
mask_float = mask.float().unsqueeze(-1)
|
| 183 |
+
denom = mask_float.sum(dim=1).clamp_min(1e-6)
|
| 184 |
+
context_aware_representations = (x * mask_float).sum(dim=1) / denom
|
| 185 |
+
context_aware_representations = self.context_encoder(context_aware_representations).unsqueeze(1)
|
| 186 |
+
c = timestep_aware_representations + context_aware_representations
|
| 187 |
+
|
| 188 |
+
x = self.input_embedder(x)
|
| 189 |
+
|
| 190 |
+
x = self.individual_token_refiner(x, c, mask)
|
| 191 |
+
|
| 192 |
+
return x
|
hymotion/pipeline/body_model.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
from ..utils.geometry import (
|
| 11 |
+
rot6d_to_rotation_matrix,
|
| 12 |
+
rotation_matrix_to_angle_axis,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# yapf: disable
|
| 16 |
+
LEFT_HAND_MEAN_AA = [ 0.1117, 0.0429, -0.4164, 0.1088, -0.0660, -0.7562, -0.0964, -0.0909,
|
| 17 |
+
-0.1885, -0.1181, 0.0509, -0.5296, -0.1437, 0.0552, -0.7049, -0.0192,
|
| 18 |
+
-0.0923, -0.3379, -0.4570, -0.1963, -0.6255, -0.2147, -0.0660, -0.5069,
|
| 19 |
+
-0.3697, -0.0603, -0.0795, -0.1419, -0.0859, -0.6355, -0.3033, -0.0579,
|
| 20 |
+
-0.6314, -0.1761, -0.1321, -0.3734, 0.8510, 0.2769, -0.0915, -0.4998,
|
| 21 |
+
0.0266, 0.0529, 0.5356, 0.0460, -0.2774]
|
| 22 |
+
RIGHT_HAND_MEAN_AA = [ 0.1117, -0.0429, 0.4164, 0.1088, 0.0660, 0.7562, -0.0964, 0.0909,
|
| 23 |
+
0.1885, -0.1181, -0.0509, 0.5296, -0.1437, -0.0552, 0.7049, -0.0192,
|
| 24 |
+
0.0923, 0.3379, -0.4570, 0.1963, 0.6255, -0.2147, 0.0660, 0.5069,
|
| 25 |
+
-0.3697, 0.0603, 0.0795, -0.1419, 0.0859, 0.6355, -0.3033, 0.0579,
|
| 26 |
+
0.6314, -0.1761, 0.1321, 0.3734, 0.8510, -0.2769, 0.0915, -0.4998,
|
| 27 |
+
-0.0266, -0.0529, 0.5356, -0.0460, 0.2774]
|
| 28 |
+
# yapf: enable
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def to_tensor(array, dtype=torch.float32, device=torch.device("cpu")):
|
| 32 |
+
if "torch.tensor" not in str(type(array)):
|
| 33 |
+
return torch.tensor(array, dtype=dtype).to(device)
|
| 34 |
+
else:
|
| 35 |
+
return array.to(device)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
|
| 39 |
+
"""Calculates the rotation matrices for a batch of rotation vectors
|
| 40 |
+
Parameters
|
| 41 |
+
----------
|
| 42 |
+
rot_vecs: torch.tensor Nx3
|
| 43 |
+
array of N axis-angle vectors
|
| 44 |
+
Returns
|
| 45 |
+
-------
|
| 46 |
+
R: torch.tensor Nx3x3
|
| 47 |
+
The rotation matrices for the given axis-angle parameters
|
| 48 |
+
"""
|
| 49 |
+
if len(rot_vecs.shape) > 2:
|
| 50 |
+
rot_vec_ori = rot_vecs
|
| 51 |
+
rot_vecs = rot_vecs.view(-1, 3)
|
| 52 |
+
else:
|
| 53 |
+
rot_vec_ori = None
|
| 54 |
+
batch_size = rot_vecs.shape[0]
|
| 55 |
+
device = rot_vecs.device
|
| 56 |
+
|
| 57 |
+
angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
|
| 58 |
+
rot_dir = rot_vecs / angle
|
| 59 |
+
|
| 60 |
+
cos = torch.unsqueeze(torch.cos(angle), dim=1)
|
| 61 |
+
sin = torch.unsqueeze(torch.sin(angle), dim=1)
|
| 62 |
+
|
| 63 |
+
# Bx1 arrays
|
| 64 |
+
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
|
| 65 |
+
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
|
| 66 |
+
|
| 67 |
+
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
|
| 68 |
+
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))
|
| 69 |
+
|
| 70 |
+
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
|
| 71 |
+
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
|
| 72 |
+
if rot_vec_ori is not None:
|
| 73 |
+
rot_mat = rot_mat.reshape(*rot_vec_ori.shape[:-1], 3, 3)
|
| 74 |
+
return rot_mat
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def load_model_data(model_path):
|
| 78 |
+
"""
|
| 79 |
+
Load wooden model data from binary files.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
model_path: path to the directory containing .bin files
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
dict containing:
|
| 86 |
+
- v_template: (V, 3) vertex template
|
| 87 |
+
- j_template: (J, 3) joint template
|
| 88 |
+
- skin_weights: (V, 4) skin weights
|
| 89 |
+
- skin_indices: (V, 4) skin indices
|
| 90 |
+
- parents: (J,) parent indices (kintree)
|
| 91 |
+
- faces: (F, 3) face indices
|
| 92 |
+
- joint_names: list of joint names
|
| 93 |
+
"""
|
| 94 |
+
model_path = Path(model_path)
|
| 95 |
+
|
| 96 |
+
# Load vertex template: (V*3,) -> (V, 3)
|
| 97 |
+
with open(model_path / "v_template.bin", "rb") as f:
|
| 98 |
+
v_template_flat = np.frombuffer(f.read(), dtype=np.float32)
|
| 99 |
+
num_verts = len(v_template_flat) // 3
|
| 100 |
+
v_template = v_template_flat.reshape(num_verts, 3)
|
| 101 |
+
|
| 102 |
+
# Load joint template: (J*3,) -> (J, 3)
|
| 103 |
+
with open(model_path / "j_template.bin", "rb") as f:
|
| 104 |
+
j_template_flat = np.frombuffer(f.read(), dtype=np.float32)
|
| 105 |
+
num_joints = len(j_template_flat) // 3
|
| 106 |
+
j_template = j_template_flat.reshape(num_joints, 3)
|
| 107 |
+
|
| 108 |
+
# Load skin weights: (V*4,) -> (V, 4), 4 bones per vertex
|
| 109 |
+
with open(model_path / "skinWeights.bin", "rb") as f:
|
| 110 |
+
skin_weights_flat = np.frombuffer(f.read(), dtype=np.float32)
|
| 111 |
+
skin_weights = skin_weights_flat.reshape(num_verts, 4)
|
| 112 |
+
|
| 113 |
+
# Load skin indices: (V*4,) -> (V, 4), 4 bone indices per vertex
|
| 114 |
+
with open(model_path / "skinIndice.bin", "rb") as f:
|
| 115 |
+
skin_indices_flat = np.frombuffer(f.read(), dtype=np.uint16)
|
| 116 |
+
skin_indices = skin_indices_flat.reshape(num_verts, 4).astype(np.int64)
|
| 117 |
+
|
| 118 |
+
# Load kintree (parent indices): (J,)
|
| 119 |
+
with open(model_path / "kintree.bin", "rb") as f:
|
| 120 |
+
parents = np.frombuffer(f.read(), dtype=np.int32)
|
| 121 |
+
|
| 122 |
+
# Load faces
|
| 123 |
+
with open(model_path / "faces.bin", "rb") as f:
|
| 124 |
+
faces_flat = np.frombuffer(f.read(), dtype=np.uint16)
|
| 125 |
+
faces = faces_flat.reshape(-1, 3)
|
| 126 |
+
|
| 127 |
+
# Load joint names
|
| 128 |
+
joint_names_path = model_path / "joint_names.json"
|
| 129 |
+
if joint_names_path.exists():
|
| 130 |
+
with open(joint_names_path, "r") as f:
|
| 131 |
+
joint_names = json.load(f)
|
| 132 |
+
else:
|
| 133 |
+
joint_names = [f"Joint_{i}" for i in range(num_joints)]
|
| 134 |
+
|
| 135 |
+
return {
|
| 136 |
+
"v_template": v_template,
|
| 137 |
+
"j_template": j_template,
|
| 138 |
+
"skin_weights": skin_weights,
|
| 139 |
+
"skin_indices": skin_indices,
|
| 140 |
+
"parents": parents,
|
| 141 |
+
"faces": faces,
|
| 142 |
+
"joint_names": joint_names,
|
| 143 |
+
"num_joints": num_joints,
|
| 144 |
+
"num_verts": num_verts,
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def simple_lbs(v_template, rot_mats, joints, parents, skin_weights, skin_indices):
|
| 149 |
+
"""
|
| 150 |
+
Simple Linear Blend Skinning without shape blending.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
v_template: (V, 3) template vertices
|
| 154 |
+
rot_mats: (B, J, 3, 3) rotation matrices for each joint
|
| 155 |
+
joints: (J, 3) joint positions in rest pose
|
| 156 |
+
parents: (J,) parent indices for each joint
|
| 157 |
+
skin_weights: (V, 4) skin weights for 4 bones per vertex
|
| 158 |
+
skin_indices: (V, 4) bone indices for 4 bones per vertex
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
vertices: (B, V, 3) transformed vertices
|
| 162 |
+
posed_joints: (B, J, 3) transformed joint positions
|
| 163 |
+
"""
|
| 164 |
+
batch_size = rot_mats.shape[0]
|
| 165 |
+
num_joints = rot_mats.shape[1]
|
| 166 |
+
num_verts = v_template.shape[0]
|
| 167 |
+
device = rot_mats.device
|
| 168 |
+
dtype = rot_mats.dtype
|
| 169 |
+
|
| 170 |
+
# Compute relative joint positions
|
| 171 |
+
rel_joints = joints.clone()
|
| 172 |
+
rel_joints[1:] = joints[1:] - joints[parents[1:]]
|
| 173 |
+
|
| 174 |
+
# Build transformation chain: transforms_mat (B, J, 4, 4)
|
| 175 |
+
transforms_mat = torch.zeros(batch_size, num_joints, 4, 4, device=device, dtype=dtype)
|
| 176 |
+
transforms_mat[..., :3, :3] = rot_mats
|
| 177 |
+
transforms_mat[..., :3, 3] = rel_joints.unsqueeze(0).expand(batch_size, -1, -1)
|
| 178 |
+
transforms_mat[..., 3, 3] = 1.0
|
| 179 |
+
|
| 180 |
+
# Forward kinematics: accumulate transforms from root to each joint
|
| 181 |
+
transform_chain = [transforms_mat[:, 0]]
|
| 182 |
+
for i in range(1, num_joints):
|
| 183 |
+
parent_idx = parents[i].item()
|
| 184 |
+
curr_transform = torch.bmm(transform_chain[parent_idx], transforms_mat[:, i])
|
| 185 |
+
transform_chain.append(curr_transform)
|
| 186 |
+
|
| 187 |
+
transforms = torch.stack(transform_chain, dim=1) # (B, J, 4, 4)
|
| 188 |
+
|
| 189 |
+
# Get posed joint positions
|
| 190 |
+
posed_joints = transforms[..., :3, 3].clone() # (B, J, 3)
|
| 191 |
+
|
| 192 |
+
# Compute relative transforms (for skinning)
|
| 193 |
+
# We need to subtract the rest pose joint positions from the transform
|
| 194 |
+
rel_transforms = transforms.clone()
|
| 195 |
+
joints_homo = F.pad(joints, [0, 1], value=0) # (J, 4)
|
| 196 |
+
transformed_rest = torch.einsum("bjcd,jd->bjc", transforms[..., :3, :], joints_homo)
|
| 197 |
+
rel_transforms[..., :3, 3] = transforms[..., :3, 3] - transformed_rest[..., :3]
|
| 198 |
+
|
| 199 |
+
# Apply skinning: gather transforms for each vertex's 4 bones
|
| 200 |
+
# skin_indices: (V, 4), skin_weights: (V, 4)
|
| 201 |
+
vertex_transforms = torch.zeros(batch_size, num_verts, 4, 4, 4, device=device, dtype=dtype)
|
| 202 |
+
for k in range(4):
|
| 203 |
+
bone_idx = skin_indices[:, k].long() # (V,)
|
| 204 |
+
vertex_transforms[:, :, k] = rel_transforms[:, bone_idx] # (B, V, 4, 4)
|
| 205 |
+
|
| 206 |
+
# Weight the transforms
|
| 207 |
+
skin_weights_expanded = skin_weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # (1, V, 4, 1, 1)
|
| 208 |
+
skin_weights_expanded = skin_weights_expanded.expand(batch_size, -1, -1, 4, 4) # (B, V, 4, 4, 4)
|
| 209 |
+
|
| 210 |
+
weighted_transforms = (vertex_transforms * skin_weights_expanded).sum(dim=2) # (B, V, 4, 4)
|
| 211 |
+
|
| 212 |
+
# Apply to vertices
|
| 213 |
+
v_homo = F.pad(v_template, [0, 1], value=1.0) # (V, 4)
|
| 214 |
+
vertices = torch.einsum("bvcd,vd->bvc", weighted_transforms[..., :3, :], v_homo) # (B, V, 3)
|
| 215 |
+
|
| 216 |
+
return vertices, posed_joints
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class WoodenMesh(torch.nn.Module):
|
| 220 |
+
"""
|
| 221 |
+
Wooden character mesh model that loads from binary files.
|
| 222 |
+
Uses simple LBS without shape blending (fixed skeleton).
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
def __init__(self, model_path="scripts/gradio/static/assets/dump_wooden"):
|
| 226 |
+
torch.nn.Module.__init__(self)
|
| 227 |
+
|
| 228 |
+
# Load model data from .bin files
|
| 229 |
+
model = load_model_data(model_path)
|
| 230 |
+
|
| 231 |
+
# Register buffers like original SMPLMesh
|
| 232 |
+
v_template = to_tensor(model["v_template"])
|
| 233 |
+
self.register_buffer("v_template", v_template)
|
| 234 |
+
|
| 235 |
+
j_template = to_tensor(model["j_template"])
|
| 236 |
+
self.register_buffer("j_template", j_template)
|
| 237 |
+
|
| 238 |
+
skin_weights = to_tensor(model["skin_weights"])
|
| 239 |
+
self.register_buffer("skin_weights", skin_weights)
|
| 240 |
+
|
| 241 |
+
skin_indices = to_tensor(model["skin_indices"], dtype=torch.long)
|
| 242 |
+
self.register_buffer("skin_indices", skin_indices)
|
| 243 |
+
|
| 244 |
+
parents = to_tensor(model["parents"], dtype=torch.long)
|
| 245 |
+
self.register_buffer("parents", parents)
|
| 246 |
+
|
| 247 |
+
# Store non-buffer attributes
|
| 248 |
+
self.faces = model["faces"]
|
| 249 |
+
self.joint_names = model["joint_names"]
|
| 250 |
+
self.num_joints = model["num_joints"]
|
| 251 |
+
self.num_verts = model["num_verts"]
|
| 252 |
+
|
| 253 |
+
print(f"[WoodenMesh] Loaded model: {self.num_verts} vertices, {self.num_joints} joints")
|
| 254 |
+
|
| 255 |
+
def forward(self, params, fast_forward=False):
|
| 256 |
+
"""
|
| 257 |
+
Forward pass to compute deformed vertices.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
params: dict containing:
|
| 261 |
+
- 'poses': (B, J*3) axis-angle rotations, or
|
| 262 |
+
- 'rot6d': (B, J, 6) 6D rotation representations
|
| 263 |
+
- 'trans': (B, 3) optional translation
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
dict with 'vertices' and 'vertices_wotrans'
|
| 267 |
+
"""
|
| 268 |
+
if "poses" in params:
|
| 269 |
+
poses = params["poses"]
|
| 270 |
+
batch_size = poses.shape[0]
|
| 271 |
+
rot_mats = batch_rodrigues(poses.view(-1, 3)).view([batch_size, -1, 3, 3])
|
| 272 |
+
elif "rot6d" in params:
|
| 273 |
+
rot6d = params["rot6d"]
|
| 274 |
+
batch_size = rot6d.shape[0]
|
| 275 |
+
rot_mats = rot6d_to_rotation_matrix(rot6d).view([batch_size, -1, 3, 3])
|
| 276 |
+
else:
|
| 277 |
+
raise ValueError("poses or rot6d must be in params")
|
| 278 |
+
|
| 279 |
+
if rot_mats.shape[1] == 22:
|
| 280 |
+
eye = torch.eye(3, device=rot_mats.device, dtype=rot_mats.dtype)[None, None, :, :].repeat(
|
| 281 |
+
batch_size, 30, 1, 1
|
| 282 |
+
)
|
| 283 |
+
rot_mats = torch.cat([rot_mats, eye], dim=1) # (B, 22 + 30, 3, 3)
|
| 284 |
+
|
| 285 |
+
# Simple LBS (no shape blending, fixed skeleton)
|
| 286 |
+
vertices, posed_joints = simple_lbs(
|
| 287 |
+
self.v_template,
|
| 288 |
+
rot_mats,
|
| 289 |
+
self.j_template,
|
| 290 |
+
self.parents,
|
| 291 |
+
self.skin_weights,
|
| 292 |
+
self.skin_indices,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Vertices without translation (for pose-level supervision)
|
| 296 |
+
vertices_wotrans = vertices
|
| 297 |
+
|
| 298 |
+
if "trans" in params:
|
| 299 |
+
trans = params["trans"]
|
| 300 |
+
vertices = vertices + trans[:, None, :]
|
| 301 |
+
|
| 302 |
+
return {
|
| 303 |
+
"vertices": vertices,
|
| 304 |
+
"vertices_wotrans": vertices_wotrans,
|
| 305 |
+
"keypoints3d": posed_joints,
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
def forward_batch(self, params):
|
| 309 |
+
assert "rot6d" in params and "trans" in params
|
| 310 |
+
rot6d = params["rot6d"]
|
| 311 |
+
trans = params["trans"]
|
| 312 |
+
bs, num_frames = rot6d.shape[:2]
|
| 313 |
+
rot6d_flat = rot6d.reshape(bs * num_frames, rot6d.shape[2], rot6d.shape[3])
|
| 314 |
+
trans_flat = trans.reshape(bs * num_frames, trans.shape[2])
|
| 315 |
+
result = self.forward(
|
| 316 |
+
{
|
| 317 |
+
"rot6d": rot6d_flat,
|
| 318 |
+
"trans": trans_flat,
|
| 319 |
+
}
|
| 320 |
+
)
|
| 321 |
+
out = {}
|
| 322 |
+
for key in result:
|
| 323 |
+
out[key] = result[key].reshape(bs, num_frames, *result[key].shape[1:])
|
| 324 |
+
return out
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def construct_smpl_data_dict(
|
| 328 |
+
rot6d: Tensor,
|
| 329 |
+
transl: Tensor,
|
| 330 |
+
betas: Optional[Tensor] = None,
|
| 331 |
+
gender: str = "neutral",
|
| 332 |
+
use_default_hand_mean_pose: bool = False,
|
| 333 |
+
) -> dict:
|
| 334 |
+
rotation_matrix = rot6d_to_rotation_matrix(rot6d)
|
| 335 |
+
angle_axis = rotation_matrix_to_angle_axis(rotation_matrix)
|
| 336 |
+
left_hand_mean_pose = (
|
| 337 |
+
torch.tensor(
|
| 338 |
+
LEFT_HAND_MEAN_AA,
|
| 339 |
+
device=angle_axis.device,
|
| 340 |
+
dtype=angle_axis.dtype,
|
| 341 |
+
)
|
| 342 |
+
.unsqueeze(0)
|
| 343 |
+
.repeat(angle_axis.shape[0], 1)
|
| 344 |
+
.reshape(angle_axis.shape[0], -1, 3)
|
| 345 |
+
)
|
| 346 |
+
right_hand_mean_pose = (
|
| 347 |
+
torch.tensor(
|
| 348 |
+
RIGHT_HAND_MEAN_AA,
|
| 349 |
+
device=angle_axis.device,
|
| 350 |
+
dtype=angle_axis.dtype,
|
| 351 |
+
)
|
| 352 |
+
.unsqueeze(0)
|
| 353 |
+
.repeat(angle_axis.shape[0], 1)
|
| 354 |
+
.reshape(angle_axis.shape[0], -1, 3)
|
| 355 |
+
)
|
| 356 |
+
if angle_axis.shape[1] == 22:
|
| 357 |
+
angle_axis = torch.cat(
|
| 358 |
+
[
|
| 359 |
+
angle_axis,
|
| 360 |
+
left_hand_mean_pose,
|
| 361 |
+
right_hand_mean_pose,
|
| 362 |
+
],
|
| 363 |
+
dim=1,
|
| 364 |
+
)
|
| 365 |
+
elif angle_axis.shape[1] == 52:
|
| 366 |
+
if use_default_hand_mean_pose:
|
| 367 |
+
angle_axis = torch.cat(
|
| 368 |
+
[
|
| 369 |
+
angle_axis[:, :22],
|
| 370 |
+
left_hand_mean_pose,
|
| 371 |
+
right_hand_mean_pose,
|
| 372 |
+
],
|
| 373 |
+
dim=1,
|
| 374 |
+
)
|
| 375 |
+
else:
|
| 376 |
+
angle_axis = angle_axis
|
| 377 |
+
|
| 378 |
+
assert angle_axis.shape[1] == 52, f"angle_axis should be 52, but got {angle_axis.shape[1]}"
|
| 379 |
+
dump = {
|
| 380 |
+
"betas": betas.cpu().numpy() if betas is not None else np.zeros((1, 16)),
|
| 381 |
+
"gender": gender,
|
| 382 |
+
"poses": angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1),
|
| 383 |
+
"trans": transl.cpu().numpy(),
|
| 384 |
+
"mocap_framerate": 30,
|
| 385 |
+
"num_frames": angle_axis.shape[0],
|
| 386 |
+
"Rh": angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1)[:, :3],
|
| 387 |
+
}
|
| 388 |
+
return dump
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
if __name__ == "__main__":
|
| 392 |
+
# python -m hymotion.pipeline.body_model
|
| 393 |
+
model_path = "scripts/gradio/static/assets/dump_wooden"
|
| 394 |
+
model = WoodenMesh(model_path)
|
| 395 |
+
params = {
|
| 396 |
+
"rot6d": torch.randn(1, 52, 6),
|
| 397 |
+
"trans": torch.randn(1, 3),
|
| 398 |
+
}
|
| 399 |
+
result = model(params)
|
| 400 |
+
print(result.keys())
|
| 401 |
+
print(result["vertices"].shape)
|
| 402 |
+
print(result["vertices_wotrans"].shape)
|
| 403 |
+
print(result["keypoints3d"].shape)
|
| 404 |
+
params_batch = {
|
| 405 |
+
"rot6d": torch.randn(3, 100, 22, 6),
|
| 406 |
+
"trans": torch.randn(3, 100, 3),
|
| 407 |
+
}
|
| 408 |
+
result_batch = model.forward_batch(params_batch)
|
| 409 |
+
print(result_batch.keys())
|
| 410 |
+
print(result_batch["vertices"].shape)
|
| 411 |
+
print(result_batch["vertices_wotrans"].shape)
|
| 412 |
+
print(result_batch["keypoints3d"].shape)
|
hymotion/pipeline/motion_diffusion.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from scipy.signal import savgol_filter
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torchdiffeq import odeint
|
| 11 |
+
|
| 12 |
+
from ..utils.geometry import (
|
| 13 |
+
matrix_to_quaternion,
|
| 14 |
+
quaternion_fix_continuity,
|
| 15 |
+
quaternion_to_matrix,
|
| 16 |
+
rot6d_to_rotation_matrix,
|
| 17 |
+
rotation_matrix_to_rot6d,
|
| 18 |
+
)
|
| 19 |
+
from ..utils.loaders import load_object
|
| 20 |
+
from ..utils.motion_process import smooth_rotation
|
| 21 |
+
from ..utils.type_converter import get_module_device
|
| 22 |
+
from .body_model import WoodenMesh
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def length_to_mask(lengths: Tensor, max_len: int) -> Tensor:
|
| 26 |
+
"""
|
| 27 |
+
lengths: (B, 1)
|
| 28 |
+
max_len: int
|
| 29 |
+
Returns: (B, max_len)
|
| 30 |
+
"""
|
| 31 |
+
assert lengths.max() <= max_len, f"lengths.max()={lengths.max()} > max_len={max_len}"
|
| 32 |
+
if lengths.ndim == 1:
|
| 33 |
+
lengths = lengths.unsqueeze(1)
|
| 34 |
+
mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths
|
| 35 |
+
return mask
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def start_end_frame_to_mask(start_frame: Tensor, end_frame: Tensor, max_len: int) -> Tensor:
|
| 39 |
+
assert (start_frame >= 0).all() and (end_frame >= 0).all(), f"start_frame={start_frame}, end_frame={end_frame}"
|
| 40 |
+
lengths = end_frame - start_frame + 1
|
| 41 |
+
assert lengths.max() <= max_len, f"lengths.max()={lengths.max()} > max_len={max_len}"
|
| 42 |
+
if lengths.ndim == 1:
|
| 43 |
+
lengths = lengths.unsqueeze(1)
|
| 44 |
+
batch_size = start_frame.shape[0]
|
| 45 |
+
arange_ids = torch.arange(max_len, device=start_frame.device).unsqueeze(0).expand(batch_size, max_len)
|
| 46 |
+
mask = (arange_ids >= start_frame.unsqueeze(1)) & (arange_ids <= end_frame.unsqueeze(1))
|
| 47 |
+
return mask
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def randn_tensor(
|
| 51 |
+
shape,
|
| 52 |
+
generator=None,
|
| 53 |
+
device=None,
|
| 54 |
+
dtype=None,
|
| 55 |
+
layout=None,
|
| 56 |
+
):
|
| 57 |
+
"""A helper function to create random tensors on the desired `device` with the desired `dtype`.
|
| 58 |
+
|
| 59 |
+
When passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the
|
| 60 |
+
tensor is always created on the CPU.
|
| 61 |
+
"""
|
| 62 |
+
# device on which tensor is created defaults to device
|
| 63 |
+
rand_device = device
|
| 64 |
+
batch_size = shape[0]
|
| 65 |
+
|
| 66 |
+
layout = layout or torch.strided
|
| 67 |
+
device = device or torch.device("cpu")
|
| 68 |
+
|
| 69 |
+
if generator is not None:
|
| 70 |
+
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
|
| 71 |
+
if gen_device_type != device.type and gen_device_type == "cpu":
|
| 72 |
+
rand_device = "cpu"
|
| 73 |
+
if device != "mps":
|
| 74 |
+
print(
|
| 75 |
+
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
|
| 76 |
+
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
|
| 77 |
+
f" slighly speed up this function by passing a generator that was created on the {device} device."
|
| 78 |
+
)
|
| 79 |
+
elif gen_device_type != device.type and gen_device_type == "cuda":
|
| 80 |
+
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
|
| 81 |
+
|
| 82 |
+
# make sure generator list of length 1 is treated like a non-list
|
| 83 |
+
if isinstance(generator, list) and len(generator) == 1:
|
| 84 |
+
generator = generator[0]
|
| 85 |
+
|
| 86 |
+
if isinstance(generator, list):
|
| 87 |
+
shape = (1,) + shape[1:]
|
| 88 |
+
latents = [
|
| 89 |
+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
|
| 90 |
+
for i in range(batch_size)
|
| 91 |
+
]
|
| 92 |
+
latents = torch.cat(latents, dim=0).to(device)
|
| 93 |
+
else:
|
| 94 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
|
| 95 |
+
|
| 96 |
+
return latents
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class MotionGeneration(torch.nn.Module):
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
network_module: str,
|
| 103 |
+
network_module_args: dict,
|
| 104 |
+
text_encoder_module: str,
|
| 105 |
+
text_encoder_cfg: dict,
|
| 106 |
+
mean_std_dir: str,
|
| 107 |
+
motion_type="auto",
|
| 108 |
+
**kwargs,
|
| 109 |
+
):
|
| 110 |
+
super().__init__()
|
| 111 |
+
# build models and parameters
|
| 112 |
+
self._network_module_args = deepcopy(network_module_args)
|
| 113 |
+
self.motion_transformer = load_object(network_module, network_module_args)
|
| 114 |
+
self._text_encoder_module = text_encoder_module
|
| 115 |
+
self._text_encoder_cfg = deepcopy(text_encoder_cfg)
|
| 116 |
+
self.motion_type = motion_type
|
| 117 |
+
|
| 118 |
+
self.null_vtxt_feat = torch.nn.Parameter(
|
| 119 |
+
torch.randn(1, 1, self._network_module_args.get("vtxt_input_dim", 768))
|
| 120 |
+
)
|
| 121 |
+
self.null_ctxt_input = torch.nn.Parameter(
|
| 122 |
+
torch.randn(1, 1, self._network_module_args.get("ctxt_input_dim", 4096))
|
| 123 |
+
)
|
| 124 |
+
self.special_game_vtxt_feat = torch.nn.Parameter(
|
| 125 |
+
torch.randn(1, 1, self._network_module_args.get("vtxt_input_dim", 768))
|
| 126 |
+
)
|
| 127 |
+
self.special_game_ctxt_feat = torch.nn.Parameter(
|
| 128 |
+
torch.randn(1, 1, self._network_module_args.get("ctxt_input_dim", 4096))
|
| 129 |
+
)
|
| 130 |
+
# build buffer
|
| 131 |
+
self.mean_std_dir = mean_std_dir
|
| 132 |
+
self._parse_buffer(self.motion_type)
|
| 133 |
+
|
| 134 |
+
self.output_mesh_fps = kwargs.get("output_mesh_fps", 30)
|
| 135 |
+
self.train_frames = kwargs.get("train_frames", 360)
|
| 136 |
+
self.uncondition_mode = kwargs.get("uncondition_mode", False)
|
| 137 |
+
self.enable_ctxt_null_feat = kwargs.get("enable_ctxt_null_feat", False)
|
| 138 |
+
self.enable_special_game_feat = kwargs.get("enable_special_game_feat", False)
|
| 139 |
+
self.random_generator_on_gpu = kwargs.get("random_generator_on_gpu", True)
|
| 140 |
+
|
| 141 |
+
def _parse_buffer(self, mode: str) -> None:
|
| 142 |
+
self.body_model = WoodenMesh()
|
| 143 |
+
self._find_motion_type(mode=mode)
|
| 144 |
+
self._load_mean_std()
|
| 145 |
+
|
| 146 |
+
def _load_mean_std(self, mean_std_name: Optional[str] = None) -> None:
|
| 147 |
+
mean_std_name = self.mean_std_dir if mean_std_name is None else mean_std_name
|
| 148 |
+
if mean_std_name is not None and osp.isdir(mean_std_name):
|
| 149 |
+
mean = torch.from_numpy(np.load(osp.join(mean_std_name, "Mean.npy"))).float()
|
| 150 |
+
std = torch.from_numpy(np.load(osp.join(mean_std_name, "Std.npy"))).float()
|
| 151 |
+
self._assert_motion_dimension(mean.unsqueeze(0), std.unsqueeze(0))
|
| 152 |
+
self.register_buffer("mean", mean)
|
| 153 |
+
self.register_buffer("std", std)
|
| 154 |
+
else:
|
| 155 |
+
print(
|
| 156 |
+
f"[{self.__class__.__name__}] No mean_std found, using blank mean_std, "
|
| 157 |
+
f"self.mean_std_dir={self.mean_std_dir}"
|
| 158 |
+
)
|
| 159 |
+
self.register_buffer("mean", torch.zeros(1))
|
| 160 |
+
self.register_buffer("std", torch.ones(1))
|
| 161 |
+
|
| 162 |
+
def _assert_motion_dimension(self, mean: Tensor, std: Tensor) -> None:
|
| 163 |
+
assert mean.shape == std.shape, f"mean.shape={mean.shape} != std.shape={std.shape}"
|
| 164 |
+
assert mean.ndim == 2, f"mean.ndim={mean.ndim} != 2"
|
| 165 |
+
assert mean.shape == (1, 201), f"mean.shape={mean.shape} != (1, 201)"
|
| 166 |
+
|
| 167 |
+
def _find_motion_type(self, mode: str) -> None:
|
| 168 |
+
if mode == "auto":
|
| 169 |
+
self.motion_type = "o6dp"
|
| 170 |
+
else:
|
| 171 |
+
self.motion_type = mode
|
| 172 |
+
|
| 173 |
+
def set_epoch(self, epoch) -> None:
|
| 174 |
+
self.current_epoch = epoch
|
| 175 |
+
|
| 176 |
+
def load_in_demo(
|
| 177 |
+
self,
|
| 178 |
+
ckpt_name: str,
|
| 179 |
+
mean_std_name: Optional[str] = None,
|
| 180 |
+
build_text_encoder: bool = True,
|
| 181 |
+
allow_empty_ckpt: bool = False,
|
| 182 |
+
) -> None:
|
| 183 |
+
if not allow_empty_ckpt:
|
| 184 |
+
if not os.path.exists(ckpt_name):
|
| 185 |
+
import warnings
|
| 186 |
+
|
| 187 |
+
warnings.warn(f"Checkpoint {ckpt_name} not found, skipping model loading")
|
| 188 |
+
else:
|
| 189 |
+
checkpoint = torch.load(ckpt_name, map_location="cpu", weights_only=False)
|
| 190 |
+
self.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 191 |
+
if mean_std_name is not None:
|
| 192 |
+
assert os.path.exists(mean_std_name), f"{mean_std_name} not found"
|
| 193 |
+
if not os.path.isfile(mean_std_name):
|
| 194 |
+
mean_std_name = None
|
| 195 |
+
self._load_mean_std(mean_std_name)
|
| 196 |
+
self.motion_transformer.eval()
|
| 197 |
+
if build_text_encoder and not self.uncondition_mode:
|
| 198 |
+
self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)
|
| 199 |
+
self.text_encoder.to(get_module_device(self))
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
|
| 202 |
+
def encode_text(self, text: Dict[str, List[str]]) -> Dict[str, Tensor]:
|
| 203 |
+
if not hasattr(self, "text_encoder"):
|
| 204 |
+
self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)
|
| 205 |
+
self.text_encoder.to(get_module_device(self))
|
| 206 |
+
text = text["text"]
|
| 207 |
+
vtxt_input, ctxt_input, ctxt_length = self.text_encoder.encode(text=text)
|
| 208 |
+
return {
|
| 209 |
+
"text_vec_raw": vtxt_input,
|
| 210 |
+
"text_ctxt_raw": ctxt_input,
|
| 211 |
+
"text_ctxt_raw_length": ctxt_length,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
def decode_motion_from_latent(self, latent: Tensor, should_apply_smooothing: bool = True) -> Dict[str, Tensor]:
|
| 215 |
+
std_zero = self.std < 1e-3
|
| 216 |
+
std = torch.where(std_zero, torch.zeros_like(self.std), self.std)
|
| 217 |
+
latent_denorm = latent * std + self.mean
|
| 218 |
+
return self._decode_o6dp(
|
| 219 |
+
latent_denorm,
|
| 220 |
+
num_joints=22,
|
| 221 |
+
rel_trans=False,
|
| 222 |
+
should_apply_smooothing=should_apply_smooothing,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
def _decode_o6dp(
|
| 226 |
+
self,
|
| 227 |
+
latent_denorm: torch.Tensor,
|
| 228 |
+
num_joints: int,
|
| 229 |
+
rel_trans: bool = False,
|
| 230 |
+
should_apply_smooothing: bool = True,
|
| 231 |
+
) -> dict:
|
| 232 |
+
device = get_module_device(self)
|
| 233 |
+
B, L = latent_denorm.shape[:2]
|
| 234 |
+
nj = num_joints
|
| 235 |
+
body_n = nj - 1
|
| 236 |
+
|
| 237 |
+
if not rel_trans:
|
| 238 |
+
transl = latent_denorm[..., 0:3].clone()
|
| 239 |
+
else:
|
| 240 |
+
transl = torch.cumsum(latent_denorm[..., 0:3].clone(), dim=1) / self.output_mesh_fps
|
| 241 |
+
root_rot6d = latent_denorm[..., 3:9].reshape(B, L, 1, 6).clone()
|
| 242 |
+
|
| 243 |
+
body6d_start = 9
|
| 244 |
+
body6d_end = body6d_start + body_n * 6
|
| 245 |
+
body_rot6d_full = latent_denorm[..., body6d_start:body6d_end].clone().reshape(B, L, body_n, 6)
|
| 246 |
+
|
| 247 |
+
# 52 joints need to be split into hands
|
| 248 |
+
left_hand_pose = right_hand_pose = None
|
| 249 |
+
if nj == 52:
|
| 250 |
+
body_rot6d = body_rot6d_full[:, :, :21, :].clone()
|
| 251 |
+
left_hand_pose = body_rot6d_full[:, :, 21:36, :].clone()
|
| 252 |
+
right_hand_pose = body_rot6d_full[:, :, 36:51, :].clone()
|
| 253 |
+
else:
|
| 254 |
+
body_rot6d = body_rot6d_full
|
| 255 |
+
|
| 256 |
+
if left_hand_pose is not None and right_hand_pose is not None:
|
| 257 |
+
body_full = torch.cat([body_rot6d, left_hand_pose, right_hand_pose], dim=2)
|
| 258 |
+
else:
|
| 259 |
+
body_full = body_rot6d
|
| 260 |
+
rot6d = torch.cat([root_rot6d, body_full], dim=2) # (B, L, nj, 6)
|
| 261 |
+
if should_apply_smooothing:
|
| 262 |
+
# only apply slerp smoothing to the first 22 joints (non-finger joints)
|
| 263 |
+
rot6d_body = rot6d[:, :, :22, :] # (B, L, 22, 6)
|
| 264 |
+
rot6d_fingers = rot6d[:, :, 22:, :] # (B, L, J-22, 6)
|
| 265 |
+
rot6d_body_smooth = self.smooth_with_slerp(rot6d_body, sigma=1.0)
|
| 266 |
+
rot6d_smooth = torch.cat([rot6d_body_smooth, rot6d_fingers], dim=2)
|
| 267 |
+
else:
|
| 268 |
+
rot6d_smooth = rot6d
|
| 269 |
+
root_rotmat_smooth = rot6d_to_rotation_matrix(rot6d_smooth[:, :, 0, :]) # (B, L, 3, 3)
|
| 270 |
+
|
| 271 |
+
transl_fixed = transl.detach()
|
| 272 |
+
if should_apply_smooothing:
|
| 273 |
+
transl_smooth = self.smooth_with_savgol(transl_fixed.detach(), window_length=11, polyorder=5)
|
| 274 |
+
else:
|
| 275 |
+
transl_smooth = transl_fixed
|
| 276 |
+
|
| 277 |
+
if self.body_model is not None:
|
| 278 |
+
print(
|
| 279 |
+
f"{self.__class__.__name__} rot6d_smooth shape: {rot6d_smooth.shape}, transl_smooth shape: {transl_smooth.shape}"
|
| 280 |
+
)
|
| 281 |
+
with torch.no_grad():
|
| 282 |
+
vertices_all = []
|
| 283 |
+
k3d_all = []
|
| 284 |
+
for bs in range(rot6d_smooth.shape[0]):
|
| 285 |
+
out = self.body_model.forward({"rot6d": rot6d_smooth[bs], "trans": transl_smooth[bs]})
|
| 286 |
+
vertices_all.append(out["vertices"])
|
| 287 |
+
k3d_all.append(out["keypoints3d"])
|
| 288 |
+
vertices = torch.stack(vertices_all, dim=0)
|
| 289 |
+
k3d = torch.stack(k3d_all, dim=0)
|
| 290 |
+
print(f"{self.__class__.__name__} vertices shape: {vertices.shape}, k3d shape: {k3d.shape}")
|
| 291 |
+
# align with the ground
|
| 292 |
+
min_y = vertices[..., 1].amin(dim=(1, 2), keepdim=True) # (B, 1, 1)
|
| 293 |
+
print(f"{self.__class__.__name__} min_y: {min_y}")
|
| 294 |
+
k3d = k3d.clone()
|
| 295 |
+
k3d[..., 1] -= min_y # (B, L, J) - (B, 1, 1)
|
| 296 |
+
transl_smooth = transl_smooth.clone()
|
| 297 |
+
transl_smooth[..., 1] -= min_y.squeeze(-1).to(device) # (B, L) - (B, 1)
|
| 298 |
+
else:
|
| 299 |
+
k3d = torch.zeros(B, L, nj, 3, device=device)
|
| 300 |
+
|
| 301 |
+
return dict(
|
| 302 |
+
latent_denorm=latent_denorm, # (B, L, 201)
|
| 303 |
+
keypoints3d=k3d, # (B, L, J, 3)
|
| 304 |
+
rot6d=rot6d_smooth, # (B, L, J, 6)
|
| 305 |
+
transl=transl_smooth, # (B, L, 3)
|
| 306 |
+
root_rotations_mat=root_rotmat_smooth, # (B, L, 3, 3)
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
@staticmethod
|
| 310 |
+
def smooth_with_savgol(input: torch.Tensor, window_length: int = 9, polyorder: int = 5) -> torch.Tensor:
|
| 311 |
+
if len(input.shape) == 2:
|
| 312 |
+
is_batch = False
|
| 313 |
+
input = input.unsqueeze(0)
|
| 314 |
+
else:
|
| 315 |
+
is_batch = True
|
| 316 |
+
input_np = input.cpu().numpy()
|
| 317 |
+
input_smooth_np = np.empty_like(input_np, dtype=np.float32)
|
| 318 |
+
for b in range(input_np.shape[0]):
|
| 319 |
+
for j in range(input_np.shape[2]):
|
| 320 |
+
input_smooth_np[b, :, j] = savgol_filter(input_np[b, :, j], window_length, polyorder)
|
| 321 |
+
input_smooth = torch.from_numpy(input_smooth_np).to(input)
|
| 322 |
+
if not is_batch:
|
| 323 |
+
input_smooth = input_smooth.squeeze(0)
|
| 324 |
+
return input_smooth
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def smooth_with_slerp(input: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
|
| 328 |
+
def fix_time_continuity(q: Tensor, time_dim: int = -3):
|
| 329 |
+
shape = q.shape
|
| 330 |
+
qv = q.moveaxis(time_dim, 0).contiguous().view(shape[time_dim], -1, 4)
|
| 331 |
+
qv = quaternion_fix_continuity(qv)
|
| 332 |
+
return qv.view(shape[time_dim], *shape[:time_dim], *shape[time_dim + 1 :]).moveaxis(0, time_dim)
|
| 333 |
+
|
| 334 |
+
num_joints = input.shape[2]
|
| 335 |
+
RR = rot6d_to_rotation_matrix(input)
|
| 336 |
+
qq = matrix_to_quaternion(RR)
|
| 337 |
+
qq_np = fix_time_continuity(qq, time_dim=1).cpu().numpy()
|
| 338 |
+
qq_s_np = smooth_rotation(
|
| 339 |
+
qq_np,
|
| 340 |
+
sigma=sigma,
|
| 341 |
+
)
|
| 342 |
+
input_smooth = rotation_matrix_to_rot6d(quaternion_to_matrix(torch.from_numpy(qq_s_np)))
|
| 343 |
+
return input_smooth.to(input.device)
|
| 344 |
+
|
| 345 |
+
@staticmethod
|
| 346 |
+
def noise_from_seeds(
|
| 347 |
+
latent: Tensor, seeds: Union[int, List[int]], seed_start: int = 0, random_generator_on_gpu: bool = True
|
| 348 |
+
) -> Tensor:
|
| 349 |
+
if isinstance(seeds, int):
|
| 350 |
+
seeds = list(range(seeds))
|
| 351 |
+
noise_list = []
|
| 352 |
+
B = latent.shape[0]
|
| 353 |
+
shape = (B, *latent.shape[1:])
|
| 354 |
+
for seed in seeds:
|
| 355 |
+
if random_generator_on_gpu:
|
| 356 |
+
generator = torch.Generator(device=latent.device).manual_seed(seed + seed_start)
|
| 357 |
+
noise_sample = randn_tensor(shape, generator=generator, device=latent.device, dtype=latent.dtype)
|
| 358 |
+
else:
|
| 359 |
+
generator = torch.Generator().manual_seed(seed + seed_start)
|
| 360 |
+
noise_sample = randn_tensor(shape, generator=generator, dtype=latent.dtype).to(latent.device)
|
| 361 |
+
noise_list.append(noise_sample)
|
| 362 |
+
return torch.cat(noise_list, dim=0)
|
| 363 |
+
|
| 364 |
+
def _maybe_inject_source_token(
|
| 365 |
+
self,
|
| 366 |
+
vtxt_input: Tensor,
|
| 367 |
+
ctxt_input: Tensor,
|
| 368 |
+
ctxt_mask_temporal: Tensor,
|
| 369 |
+
sources: Optional[List[str]],
|
| 370 |
+
trigger_sources: Optional[set] = None,
|
| 371 |
+
prob: float = 0.5,
|
| 372 |
+
) -> Tuple[Tensor, Tensor, Tensor]:
|
| 373 |
+
if (sources is None or trigger_sources is None) or not self.enable_special_game_feat:
|
| 374 |
+
return vtxt_input, ctxt_input, ctxt_mask_temporal
|
| 375 |
+
|
| 376 |
+
B, Lc, Dc = ctxt_input.shape
|
| 377 |
+
assert (
|
| 378 |
+
isinstance(sources, (list, tuple)) and len(sources) == B
|
| 379 |
+
), f"sources length should be equal to batch: {len(sources)} vs {B}"
|
| 380 |
+
|
| 381 |
+
trig = set(s.lower() for s in trigger_sources)
|
| 382 |
+
src_mask = torch.tensor(
|
| 383 |
+
[str(s).lower() in trig for s in sources], dtype=torch.bool, device=ctxt_input.device
|
| 384 |
+
) # (B,)
|
| 385 |
+
if not src_mask.any():
|
| 386 |
+
return vtxt_input, ctxt_input, ctxt_mask_temporal
|
| 387 |
+
|
| 388 |
+
rand_mask = (
|
| 389 |
+
torch.rand(B, device=ctxt_input.device) < prob
|
| 390 |
+
if self.training
|
| 391 |
+
else torch.BoolTensor(B).fill_(True).to(ctxt_input.device)
|
| 392 |
+
)
|
| 393 |
+
apply_mask = src_mask & rand_mask
|
| 394 |
+
if not apply_mask.any():
|
| 395 |
+
return vtxt_input, ctxt_input, ctxt_mask_temporal
|
| 396 |
+
|
| 397 |
+
# vtxt: only add mixture to the hit samples
|
| 398 |
+
vtxt_token = self.special_game_vtxt_feat.to(vtxt_input).expand(B, 1, -1)
|
| 399 |
+
vtxt_input = vtxt_input + vtxt_token * apply_mask.view(B, 1, 1).to(vtxt_input.dtype)
|
| 400 |
+
|
| 401 |
+
# calculate the current effective length of each sample
|
| 402 |
+
if ctxt_mask_temporal.dtype == torch.bool:
|
| 403 |
+
cur_len = ctxt_mask_temporal.sum(dim=1).long() # (B,)
|
| 404 |
+
else:
|
| 405 |
+
cur_len = (ctxt_mask_temporal > 0).sum(dim=1).long()
|
| 406 |
+
|
| 407 |
+
# for the "not full" hit samples,
|
| 408 |
+
# write the special token at the cur_len position,
|
| 409 |
+
# and set the mask to True
|
| 410 |
+
can_inplace = apply_mask & (cur_len < Lc)
|
| 411 |
+
b_inplace = torch.nonzero(can_inplace, as_tuple=False).squeeze(1) # (K,)
|
| 412 |
+
if b_inplace.numel() > 0:
|
| 413 |
+
pos = cur_len[b_inplace] # (K,)
|
| 414 |
+
token = self.special_game_ctxt_feat.squeeze(0).squeeze(0).to(ctxt_input) # (Dc,)
|
| 415 |
+
ctxt_input[b_inplace, pos, :] = token.unsqueeze(0).expand(b_inplace.numel(), Dc)
|
| 416 |
+
if ctxt_mask_temporal.dtype == torch.bool:
|
| 417 |
+
ctxt_mask_temporal[b_inplace, pos] = True
|
| 418 |
+
else:
|
| 419 |
+
ctxt_mask_temporal[b_inplace, pos] = 1
|
| 420 |
+
|
| 421 |
+
# if there are "full" hit samples, need to pad one:
|
| 422 |
+
# the full samples write the special token at the new position,
|
| 423 |
+
# other samples pad zero and mask=False
|
| 424 |
+
need_expand = (apply_mask & (cur_len >= Lc)).any()
|
| 425 |
+
if need_expand:
|
| 426 |
+
suffix = torch.zeros((B, 1, Dc), dtype=ctxt_input.dtype, device=ctxt_input.device)
|
| 427 |
+
full_hit = apply_mask & (cur_len >= Lc)
|
| 428 |
+
b_full = torch.nonzero(full_hit, as_tuple=False).squeeze(1)
|
| 429 |
+
if b_full.numel() > 0:
|
| 430 |
+
suffix[b_full, 0, :] = (
|
| 431 |
+
self.special_game_ctxt_feat.expand(b_full.numel(), 1, -1).to(ctxt_input).squeeze(1)
|
| 432 |
+
)
|
| 433 |
+
ctxt_input = torch.cat([ctxt_input, suffix], dim=1)
|
| 434 |
+
|
| 435 |
+
if ctxt_mask_temporal.dtype == torch.bool:
|
| 436 |
+
suffix_mask = torch.zeros((B, 1), dtype=torch.bool, device=ctxt_input.device)
|
| 437 |
+
suffix_mask[b_full, 0] = True
|
| 438 |
+
else:
|
| 439 |
+
suffix_mask = torch.zeros((B, 1), dtype=ctxt_mask_temporal.dtype, device=ctxt_input.device)
|
| 440 |
+
suffix_mask[b_full, 0] = 1
|
| 441 |
+
ctxt_mask_temporal = torch.cat([ctxt_mask_temporal, suffix_mask], dim=1)
|
| 442 |
+
|
| 443 |
+
return vtxt_input, ctxt_input, ctxt_mask_temporal
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class MotionFlowMatching(MotionGeneration):
|
| 447 |
+
def __init__(
|
| 448 |
+
self,
|
| 449 |
+
network_module: str,
|
| 450 |
+
network_module_args: dict,
|
| 451 |
+
text_encoder_module: str,
|
| 452 |
+
text_encoder_cfg: dict,
|
| 453 |
+
noise_scheduler_cfg: dict = {"method": "euler"},
|
| 454 |
+
infer_noise_scheduler_cfg: dict = {"validation_steps": 50},
|
| 455 |
+
mean_std_dir: Optional[str] = None,
|
| 456 |
+
losses_cfg: Optional[dict] = None,
|
| 457 |
+
train_cfg: Optional[dict] = None,
|
| 458 |
+
test_cfg: Optional[dict] = None,
|
| 459 |
+
**kwargs,
|
| 460 |
+
):
|
| 461 |
+
super().__init__(
|
| 462 |
+
network_module=network_module,
|
| 463 |
+
network_module_args=network_module_args,
|
| 464 |
+
text_encoder_module=text_encoder_module,
|
| 465 |
+
text_encoder_cfg=text_encoder_cfg,
|
| 466 |
+
losses_cfg=losses_cfg,
|
| 467 |
+
mean_std_dir=(mean_std_dir if mean_std_dir is not None else test_cfg.get("mean_std_dir", None)),
|
| 468 |
+
**kwargs,
|
| 469 |
+
)
|
| 470 |
+
# build scheduler
|
| 471 |
+
self._noise_scheduler_cfg = deepcopy(noise_scheduler_cfg)
|
| 472 |
+
self._infer_noise_scheduler_cfg = deepcopy(infer_noise_scheduler_cfg)
|
| 473 |
+
# additional cfg
|
| 474 |
+
self.train_cfg = deepcopy(train_cfg) if train_cfg else dict()
|
| 475 |
+
self.test_cfg = deepcopy(test_cfg) if test_cfg else dict()
|
| 476 |
+
self._parse_test_cfg()
|
| 477 |
+
|
| 478 |
+
def _parse_test_cfg(self) -> None:
|
| 479 |
+
self.validation_steps = self._infer_noise_scheduler_cfg["validation_steps"]
|
| 480 |
+
self.text_guidance_scale = self.test_cfg.get("text_guidance_scale", 1)
|
| 481 |
+
|
| 482 |
+
@torch.no_grad()
|
| 483 |
+
def generate(
|
| 484 |
+
self,
|
| 485 |
+
text: Union[str, List[str]],
|
| 486 |
+
seed_input: List[int],
|
| 487 |
+
duration_slider: int,
|
| 488 |
+
cfg_scale: Optional[float] = None,
|
| 489 |
+
use_special_game_feat: bool = False,
|
| 490 |
+
hidden_state_dict=None,
|
| 491 |
+
length=None,
|
| 492 |
+
) -> Dict[str, Any]:
|
| 493 |
+
device = get_module_device(self)
|
| 494 |
+
if length is None:
|
| 495 |
+
length = int(round(duration_slider * self.output_mesh_fps))
|
| 496 |
+
assert (
|
| 497 |
+
0 < length < 5000
|
| 498 |
+
), f"input duration_slider must be in (0, {5000/self.output_mesh_fps}] due to rope, but got {duration_slider}"
|
| 499 |
+
if length > self.train_frames or length < min(self.train_frames, 20):
|
| 500 |
+
print(f">>> given length is too long or too short, got {length}, will be truncated")
|
| 501 |
+
length = min(length, self.train_frames)
|
| 502 |
+
length = max(length, min(self.train_frames, 20))
|
| 503 |
+
|
| 504 |
+
repeat = len(seed_input)
|
| 505 |
+
if isinstance(text, list):
|
| 506 |
+
assert len(text) == repeat, f"len(text) must equal len(seed_input), got {len(text)} vs {repeat}"
|
| 507 |
+
text_list = text
|
| 508 |
+
elif isinstance(text, str):
|
| 509 |
+
text_list = [text] * repeat
|
| 510 |
+
else:
|
| 511 |
+
raise TypeError(f"Unsupported text type: {type(text)}")
|
| 512 |
+
|
| 513 |
+
if not self.uncondition_mode:
|
| 514 |
+
if hidden_state_dict is None:
|
| 515 |
+
hidden_state_dict = self.encode_text({"text": text_list})
|
| 516 |
+
vtxt_input = hidden_state_dict["text_vec_raw"]
|
| 517 |
+
ctxt_input = hidden_state_dict["text_ctxt_raw"]
|
| 518 |
+
ctxt_length = hidden_state_dict["text_ctxt_raw_length"]
|
| 519 |
+
# check shape
|
| 520 |
+
if len(vtxt_input.shape) == 2 and len(ctxt_input.shape) == 2:
|
| 521 |
+
vtxt_input = vtxt_input[None].repeat(repeat, 1, 1)
|
| 522 |
+
ctxt_input = ctxt_input[None].repeat(repeat, 1, 1)
|
| 523 |
+
ctxt_length = ctxt_length.repeat(repeat)
|
| 524 |
+
ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
|
| 525 |
+
sources = None if not use_special_game_feat else ["Game"] * repeat
|
| 526 |
+
vtxt_input, ctxt_input, ctxt_mask_temporal = self._maybe_inject_source_token(
|
| 527 |
+
vtxt_input, ctxt_input, ctxt_mask_temporal, sources, trigger_sources={"Taobao", "Game"}
|
| 528 |
+
)
|
| 529 |
+
else:
|
| 530 |
+
vtxt_input = self.null_vtxt_feat.expand(repeat, 1, -1)
|
| 531 |
+
ctxt_input = self.null_ctxt_input.expand(repeat, 1, -1)
|
| 532 |
+
ctxt_length = torch.tensor([1]).expand(repeat)
|
| 533 |
+
ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1]).expand(repeat, -1)
|
| 534 |
+
assert len(vtxt_input.shape) == 3, f"vtxt_input.shape: {vtxt_input.shape}, should be (B, 1, D)"
|
| 535 |
+
assert len(ctxt_input.shape) == 3, f"ctxt_input.shape: {ctxt_input.shape}, should be (B, 1, D)"
|
| 536 |
+
assert len(ctxt_length.shape) == 1, f"ctxt_length.shape: {ctxt_length.shape}, should be (B,)"
|
| 537 |
+
|
| 538 |
+
ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
|
| 539 |
+
x_length = torch.LongTensor([length] * repeat).to(device)
|
| 540 |
+
x_mask_temporal = length_to_mask(x_length, self.train_frames)
|
| 541 |
+
|
| 542 |
+
text_guidance_scale = cfg_scale if cfg_scale is not None else self.text_guidance_scale
|
| 543 |
+
do_classifier_free_guidance = text_guidance_scale > 1.0 and not self.uncondition_mode
|
| 544 |
+
if do_classifier_free_guidance is True:
|
| 545 |
+
silent_text_feat = self.null_vtxt_feat.expand(*vtxt_input.shape)
|
| 546 |
+
vtxt_input = torch.cat([silent_text_feat, vtxt_input], dim=0)
|
| 547 |
+
|
| 548 |
+
if self.enable_ctxt_null_feat:
|
| 549 |
+
silent_ctxt_input = self.null_ctxt_input.expand(*ctxt_input.shape)
|
| 550 |
+
else:
|
| 551 |
+
silent_ctxt_input = ctxt_input
|
| 552 |
+
ctxt_input = torch.cat([silent_ctxt_input, ctxt_input], dim=0)
|
| 553 |
+
|
| 554 |
+
ctxt_mask_temporal = torch.cat([ctxt_mask_temporal] * 2, dim=0)
|
| 555 |
+
x_mask_temporal = torch.cat([x_mask_temporal] * 2, dim=0)
|
| 556 |
+
|
| 557 |
+
def fn(t: Tensor, x: Tensor) -> Tensor:
|
| 558 |
+
# predict flow
|
| 559 |
+
x_input = torch.cat([x] * 2, dim=0) if do_classifier_free_guidance else x
|
| 560 |
+
x_pred = self.motion_transformer(
|
| 561 |
+
x=x_input,
|
| 562 |
+
ctxt_input=ctxt_input,
|
| 563 |
+
vtxt_input=vtxt_input,
|
| 564 |
+
timesteps=t.expand(x_input.shape[0]),
|
| 565 |
+
x_mask_temporal=x_mask_temporal,
|
| 566 |
+
ctxt_mask_temporal=ctxt_mask_temporal,
|
| 567 |
+
)
|
| 568 |
+
if do_classifier_free_guidance:
|
| 569 |
+
x_pred_basic, x_pred_text = x_pred.chunk(2, dim=0)
|
| 570 |
+
x_pred = x_pred_basic + text_guidance_scale * (x_pred_text - x_pred_basic)
|
| 571 |
+
return x_pred
|
| 572 |
+
|
| 573 |
+
# duplicate test corner for inner time step oberservation
|
| 574 |
+
t = torch.linspace(0, 1, self.validation_steps + 1, device=device)
|
| 575 |
+
y0 = self.noise_from_seeds(
|
| 576 |
+
torch.zeros(
|
| 577 |
+
1,
|
| 578 |
+
self.train_frames,
|
| 579 |
+
self._network_module_args["input_dim"],
|
| 580 |
+
device=device,
|
| 581 |
+
),
|
| 582 |
+
seed_input,
|
| 583 |
+
random_generator_on_gpu=self.random_generator_on_gpu,
|
| 584 |
+
)
|
| 585 |
+
with torch.no_grad():
|
| 586 |
+
trajectory = odeint(fn, y0, t, **self._noise_scheduler_cfg)
|
| 587 |
+
sampled = trajectory[-1]
|
| 588 |
+
assert isinstance(sampled, Tensor), f"sampled must be a Tensor, but got {type(sampled)}"
|
| 589 |
+
sampled = sampled[:, :length, ...].clone()
|
| 590 |
+
|
| 591 |
+
output_dict = self.decode_motion_from_latent(sampled, should_apply_smooothing=True)
|
| 592 |
+
|
| 593 |
+
return {
|
| 594 |
+
**output_dict,
|
| 595 |
+
"text": text,
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
if __name__ == "__main__":
|
| 600 |
+
# python -m hymotion.pipeline.motion_diffusion
|
| 601 |
+
import time
|
| 602 |
+
|
| 603 |
+
import torch
|
| 604 |
+
|
| 605 |
+
device = "cuda:0"
|
| 606 |
+
bsz, input_dim = 64, 272
|
| 607 |
+
seq_lens = [90, 180, 360]
|
| 608 |
+
ctxt_seq_lens = 64
|
| 609 |
+
warmup = 5
|
| 610 |
+
repeats = 100
|
| 611 |
+
|
| 612 |
+
network_module = "hymotion/network/hymotion_mmdit.HunyuanMotionMMDiT"
|
| 613 |
+
network_module_args = {
|
| 614 |
+
"input_dim": input_dim,
|
| 615 |
+
"feat_dim": 512,
|
| 616 |
+
"ctxt_input_dim": 4096,
|
| 617 |
+
"vtxt_input_dim": 768,
|
| 618 |
+
"num_layers": 12,
|
| 619 |
+
"num_heads": 4,
|
| 620 |
+
"mlp_ratio": 2.0,
|
| 621 |
+
"dropout": 0.0,
|
| 622 |
+
"mask_mode": "narrowband",
|
| 623 |
+
}
|
| 624 |
+
text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel"
|
| 625 |
+
text_encoder_cfg = {"llm_type": "qwen3", "max_length_llm": ctxt_seq_lens}
|
| 626 |
+
|
| 627 |
+
# ================================ FM_MMDiT ================================
|
| 628 |
+
FM_MMDiT = MotionFlowMatching(
|
| 629 |
+
network_module=network_module,
|
| 630 |
+
network_module_args=network_module_args,
|
| 631 |
+
text_encoder_module=text_encoder_module,
|
| 632 |
+
text_encoder_cfg=text_encoder_cfg,
|
| 633 |
+
noise_scheduler_module={"method": "euler"},
|
| 634 |
+
infer_noise_scheduler_cfg={"validation_steps": 50},
|
| 635 |
+
train_cfg={"cond_mask_prob": 0.1},
|
| 636 |
+
test_cfg={
|
| 637 |
+
"text_guidance_scale": 1.5,
|
| 638 |
+
},
|
| 639 |
+
).to(device)
|
hymotion/prompt_engineering/model_constants.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__all__ = [
|
| 2 |
+
"REWRITE_AND_INFER_TIME_PROMPT_FORMAT",
|
| 3 |
+
]
|
| 4 |
+
|
| 5 |
+
REWRITE_AND_INFER_TIME_PROMPT_FORMAT = """
|
| 6 |
+
# Role
|
| 7 |
+
You are an expert in 3D motion analysis, animation timing, and choreography. Your task is to analyze textual action descriptions to estimate execution time and standardize the language for motion generation systems.
|
| 8 |
+
|
| 9 |
+
# Task
|
| 10 |
+
Analyze the user-provided [Input Action] and generate a structured JSON response containing a duration estimate and a refined caption.
|
| 11 |
+
|
| 12 |
+
# Instructions
|
| 13 |
+
|
| 14 |
+
### 1. Duration Estimation (frame_count)
|
| 15 |
+
- Analyze the complexity, speed, and physical constraints of the described action.
|
| 16 |
+
- Estimate the time required to perform the action in a **smooth, natural, and realistic manner**.
|
| 17 |
+
- Calculate the total duration in frames based on a **30 fps** (frames per second) standard.
|
| 18 |
+
- Output strictly as an Integer.
|
| 19 |
+
|
| 20 |
+
### 2. Caption Refinement (short_caption)
|
| 21 |
+
- Generate a refined, grammatically correct version of the input description in **English**.
|
| 22 |
+
- **Strict Constraints**:
|
| 23 |
+
- You must **PRESERVE** the original sequence of events (chronological order).
|
| 24 |
+
- You must **RETAIN** all original spatial modifiers (e.g., "left," "upward," "quickly").
|
| 25 |
+
- **DO NOT** add new sub-actions or hallucinate details not present in the input.
|
| 26 |
+
- **DO NOT** delete any specific movements.
|
| 27 |
+
- The goal is to improve clarity and flow while maintaining 100% semantic fidelity to the original request.
|
| 28 |
+
|
| 29 |
+
### 3. Output Format
|
| 30 |
+
- Return **ONLY** a raw JSON object.
|
| 31 |
+
- Do not use Markdown formatting (i.e., do not use ```json ... ```).
|
| 32 |
+
- Ensure the JSON is valid and parsable.
|
| 33 |
+
|
| 34 |
+
# JSON Structure
|
| 35 |
+
{{
|
| 36 |
+
"duration": <Integer, frames at 30fps>,
|
| 37 |
+
"short_caption": "<String, the refined English description>"
|
| 38 |
+
}}
|
| 39 |
+
|
| 40 |
+
# Input
|
| 41 |
+
{}
|
| 42 |
+
"""
|
hymotion/prompt_engineering/prompt_rewrite.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# prompt_rewrite.py
|
| 2 |
+
import base64
|
| 3 |
+
import concurrent.futures
|
| 4 |
+
import datetime
|
| 5 |
+
import hashlib
|
| 6 |
+
import hmac
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import random
|
| 10 |
+
import re
|
| 11 |
+
import time
|
| 12 |
+
import uuid
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
| 15 |
+
|
| 16 |
+
from openai import OpenAI
|
| 17 |
+
from requests import exceptions as req_exc
|
| 18 |
+
|
| 19 |
+
from .model_constants import REWRITE_AND_INFER_TIME_PROMPT_FORMAT
|
| 20 |
+
|
| 21 |
+
# logging.basicConfig(level=logging.INFO)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class ApiConfig:
|
| 26 |
+
host: str
|
| 27 |
+
user: str
|
| 28 |
+
apikey: str
|
| 29 |
+
model: str
|
| 30 |
+
api_version: Optional[str] = None
|
| 31 |
+
timeout: int = 3600
|
| 32 |
+
source: str = "hymotion"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class RetryConfig:
|
| 37 |
+
max_retries: int = 20
|
| 38 |
+
base_delay: float = 1.0
|
| 39 |
+
timeout: float = 30.0
|
| 40 |
+
retry_status: Tuple[int, ...] = (429, 500, 502, 503, 504)
|
| 41 |
+
max_delay: float = 1.0
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ApiError(Exception):
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ResponseParseError(Exception):
|
| 49 |
+
pass
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class OpenAIChatApi:
|
| 53 |
+
def __init__(self, config: ApiConfig) -> None:
|
| 54 |
+
self.logger = logging.getLogger(__name__)
|
| 55 |
+
self.config = config
|
| 56 |
+
self.client = OpenAI(
|
| 57 |
+
api_key=self.config.apikey,
|
| 58 |
+
base_url=self.config.host,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def call_data_eval(self, data: Union[str, Dict[str, Any]]):
|
| 62 |
+
if isinstance(data, dict) and "messages" in data:
|
| 63 |
+
raw_msgs = data["messages"]
|
| 64 |
+
messages: List[Dict[str, str]] = []
|
| 65 |
+
for m in raw_msgs:
|
| 66 |
+
role = m.get("role", "user")
|
| 67 |
+
content = m.get("content", "")
|
| 68 |
+
if isinstance(content, list):
|
| 69 |
+
parts = []
|
| 70 |
+
for p in content:
|
| 71 |
+
if isinstance(p, dict) and ("text" in p):
|
| 72 |
+
parts.append(str(p.get("text", "")))
|
| 73 |
+
content = " ".join([t for t in parts if t])
|
| 74 |
+
elif not isinstance(content, str):
|
| 75 |
+
content = str(content)
|
| 76 |
+
messages.append({"role": role, "content": content})
|
| 77 |
+
payload = {"model": self.config.model, "messages": messages}
|
| 78 |
+
for k in (
|
| 79 |
+
"temperature",
|
| 80 |
+
"top_p",
|
| 81 |
+
"max_tokens",
|
| 82 |
+
"n",
|
| 83 |
+
"stop",
|
| 84 |
+
"presence_penalty",
|
| 85 |
+
"frequency_penalty",
|
| 86 |
+
"user",
|
| 87 |
+
):
|
| 88 |
+
if k in data:
|
| 89 |
+
payload[k] = data[k]
|
| 90 |
+
else:
|
| 91 |
+
payload = {"model": self.config.model, "messages": [{"role": "user", "content": str(data)}]}
|
| 92 |
+
try:
|
| 93 |
+
resp = self.client.chat.completions.create(**payload)
|
| 94 |
+
return resp
|
| 95 |
+
except Exception as e:
|
| 96 |
+
self.logger.error(f"OpenAI API call failed: {e}")
|
| 97 |
+
raise ApiError(f"OpenAI API call failed: {e}") from e
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class ResponseParser:
|
| 101 |
+
def __init__(self):
|
| 102 |
+
self.logger = logging.getLogger(__name__)
|
| 103 |
+
|
| 104 |
+
def call_data_eval_with_retry(
|
| 105 |
+
self, api: Union[OpenAIChatApi], data: str, retry_config: Optional[RetryConfig] = None
|
| 106 |
+
) -> Tuple[Union[Dict[str, Any], int], float, float]:
|
| 107 |
+
if retry_config is None:
|
| 108 |
+
retry_config = RetryConfig()
|
| 109 |
+
|
| 110 |
+
last_error = None
|
| 111 |
+
for attempt in range(retry_config.max_retries):
|
| 112 |
+
start_time = time.time()
|
| 113 |
+
cost = 0.0
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
result = self._execute_request(api, data)
|
| 117 |
+
end_time = time.time()
|
| 118 |
+
parsed_result = self._parse_answer(result)
|
| 119 |
+
self._validate_result(parsed_result)
|
| 120 |
+
return parsed_result, cost, end_time - start_time
|
| 121 |
+
|
| 122 |
+
except (
|
| 123 |
+
concurrent.futures.TimeoutError,
|
| 124 |
+
req_exc.RequestException,
|
| 125 |
+
json.JSONDecodeError,
|
| 126 |
+
ValueError,
|
| 127 |
+
TypeError,
|
| 128 |
+
ResponseParseError,
|
| 129 |
+
) as e:
|
| 130 |
+
last_error = e
|
| 131 |
+
self.logger.warning(f"Attempt {attempt + 1} failed: {e}")
|
| 132 |
+
if isinstance(e, req_exc.RequestException) and hasattr(e, "response"):
|
| 133 |
+
if e.response is not None and e.response.status_code not in retry_config.retry_status:
|
| 134 |
+
raise ApiError(f"Non-retryable error: {e.response.status_code}") from e
|
| 135 |
+
if attempt < retry_config.max_retries - 1:
|
| 136 |
+
delay = self._calculate_delay(attempt, retry_config)
|
| 137 |
+
self.logger.info(f"JSON parsing failed, {delay:.1f} seconds later retry...")
|
| 138 |
+
time.sleep(delay)
|
| 139 |
+
|
| 140 |
+
raise ApiError(f"Retry {retry_config.max_retries} times but still failed") from last_error
|
| 141 |
+
|
| 142 |
+
def _execute_request(self, api: Union[OpenAIChatApi], data: str) -> Dict[str, Any]:
|
| 143 |
+
response = api.call_data_eval(data)
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
if hasattr(response, "model_dump"):
|
| 147 |
+
return response.model_dump()
|
| 148 |
+
if isinstance(response, dict):
|
| 149 |
+
return response
|
| 150 |
+
if hasattr(response, "__dict__"):
|
| 151 |
+
return json.loads(json.dumps(response.__dict__, default=str))
|
| 152 |
+
except Exception as e:
|
| 153 |
+
raise ResponseParseError(f"Unable to parse OpenAI returned object: {type(response)} - {e}") from e
|
| 154 |
+
|
| 155 |
+
raise ResponseParseError(f"Unknown response type: {type(response)}")
|
| 156 |
+
|
| 157 |
+
def _extract_cost(self, payload: Dict[str, Any]) -> float:
|
| 158 |
+
try:
|
| 159 |
+
return float(payload.get("cost_info", {}).get("cost", 0)) / 1e6
|
| 160 |
+
except (AttributeError, KeyError):
|
| 161 |
+
return 0.0
|
| 162 |
+
|
| 163 |
+
def _validate_result(self, result: Union[Dict[str, Any], int]) -> None:
|
| 164 |
+
if isinstance(result, int):
|
| 165 |
+
return
|
| 166 |
+
elif isinstance(result, dict):
|
| 167 |
+
required_fields = ["duration", "short_caption"]
|
| 168 |
+
for field in required_fields:
|
| 169 |
+
if not isinstance(result.get(field), (int, str)):
|
| 170 |
+
raise ResponseParseError(f"LLM returned invalid format: {field}")
|
| 171 |
+
else:
|
| 172 |
+
raise ResponseParseError(f"Unsupported answer type: {type(result)}")
|
| 173 |
+
|
| 174 |
+
def _calculate_delay(self, attempt: int, config: RetryConfig) -> float:
|
| 175 |
+
delay = config.base_delay * (2**attempt) * (0.5 + random.random())
|
| 176 |
+
return min(delay, config.max_delay)
|
| 177 |
+
|
| 178 |
+
def _parse_answer(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
| 179 |
+
if isinstance(payload, dict) and "choices" in payload:
|
| 180 |
+
return self._parse_from_choices_field(payload)
|
| 181 |
+
|
| 182 |
+
raise ResponseParseError("Unknown response format: expected choices")
|
| 183 |
+
|
| 184 |
+
def _parse_from_choices_field(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
| 185 |
+
choices = payload.get("choices") or []
|
| 186 |
+
if not choices:
|
| 187 |
+
raise ResponseParseError("OpenAI returned empty")
|
| 188 |
+
|
| 189 |
+
content = self._extract_content_from_choice(choices[0])
|
| 190 |
+
|
| 191 |
+
if not isinstance(content, str) or not content.strip():
|
| 192 |
+
raise ResponseParseError("OpenAI returned no valid content")
|
| 193 |
+
|
| 194 |
+
return self._parse_json_content(content)
|
| 195 |
+
|
| 196 |
+
def _extract_content_from_choice(self, choice: Any) -> Optional[str]:
|
| 197 |
+
content = None
|
| 198 |
+
|
| 199 |
+
if isinstance(choice, dict):
|
| 200 |
+
# Try message content first
|
| 201 |
+
msg = choice.get("message") or {}
|
| 202 |
+
content = msg.get("content")
|
| 203 |
+
# Fallback to delta content or text
|
| 204 |
+
if content is None:
|
| 205 |
+
delta = choice.get("delta") or {}
|
| 206 |
+
content = delta.get("content", choice.get("text"))
|
| 207 |
+
else:
|
| 208 |
+
# Handle object-like choice (e.g. Pydantic model)
|
| 209 |
+
msg = getattr(choice, "message", None)
|
| 210 |
+
if msg is not None:
|
| 211 |
+
content = getattr(msg, "content", None)
|
| 212 |
+
|
| 213 |
+
if content is None:
|
| 214 |
+
delta = getattr(choice, "delta", None)
|
| 215 |
+
if delta is not None:
|
| 216 |
+
content = getattr(delta, "content", None)
|
| 217 |
+
|
| 218 |
+
if content is None:
|
| 219 |
+
content = getattr(choice, "text", None)
|
| 220 |
+
|
| 221 |
+
return content
|
| 222 |
+
|
| 223 |
+
def _parse_json_content(self, content: str) -> Dict[str, Any]:
|
| 224 |
+
cleaned = self._cleanup_fenced_json(content)
|
| 225 |
+
try:
|
| 226 |
+
return json.loads(cleaned)
|
| 227 |
+
except json.JSONDecodeError as e:
|
| 228 |
+
self.logger.warning(f"JSON parsing failed, original content: {cleaned[:500]}...")
|
| 229 |
+
raise ResponseParseError(f"JSON parsing failed: {e}") from e
|
| 230 |
+
|
| 231 |
+
def _cleanup_fenced_json(self, text: str) -> str:
|
| 232 |
+
text = text.strip()
|
| 233 |
+
if text.startswith("```"):
|
| 234 |
+
text = re.sub(r"^```(?:json)?\s*", "", text)
|
| 235 |
+
text = re.sub(r"\s*```$", "", text)
|
| 236 |
+
if not text.lstrip().startswith("{") and "{" in text and "}" in text:
|
| 237 |
+
start = text.find("{")
|
| 238 |
+
end = text.rfind("}")
|
| 239 |
+
if 0 <= start < end:
|
| 240 |
+
text = text[start : end + 1]
|
| 241 |
+
return text
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class PromptRewriter:
|
| 245 |
+
def __init__(self, host: Optional[str] = None, parser: Optional[ResponseParser] = None):
|
| 246 |
+
self.parser = parser or ResponseParser()
|
| 247 |
+
self.logger = logging.getLogger(__name__)
|
| 248 |
+
self.api = OpenAIChatApi(
|
| 249 |
+
ApiConfig(
|
| 250 |
+
host=host,
|
| 251 |
+
user="",
|
| 252 |
+
apikey="EMPTY",
|
| 253 |
+
model="Qwen3-30B-A3B-SFT",
|
| 254 |
+
api_version="",
|
| 255 |
+
)
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
def rewrite_prompt_and_infer_time(
|
| 259 |
+
self,
|
| 260 |
+
text: str,
|
| 261 |
+
prompt_format: str = REWRITE_AND_INFER_TIME_PROMPT_FORMAT,
|
| 262 |
+
retry_config: Optional[RetryConfig] = None,
|
| 263 |
+
) -> Tuple[float, str]:
|
| 264 |
+
self.logger.info("Start rewriting prompt...")
|
| 265 |
+
try:
|
| 266 |
+
result, cost, elapsed = self.parser.call_data_eval_with_retry(
|
| 267 |
+
self.api, prompt_format.format(text), retry_config
|
| 268 |
+
)
|
| 269 |
+
self.logger.info(f"Rewriting completed - cost: {cost:.6f}, time: {elapsed:.2f}s")
|
| 270 |
+
return round(float(result["duration"]) / 30.0, 2), result["short_caption"]
|
| 271 |
+
|
| 272 |
+
except Exception as e:
|
| 273 |
+
self.logger.error(f"Prompt rewriting failed: {e}")
|
| 274 |
+
raise
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
if __name__ == "__main__":
|
| 278 |
+
# python -m hymotion.prompt_engineering.prompt_rewrite
|
| 279 |
+
|
| 280 |
+
logging.basicConfig(level=logging.INFO)
|
| 281 |
+
text = "person jumps after they runs"
|
| 282 |
+
prompt_rewriter = PromptRewriter()
|
| 283 |
+
result = prompt_rewriter.rewrite_prompt_and_infer_time(text)
|
| 284 |
+
print(result)
|
hymotion/utils/configs.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import copy
|
| 3 |
+
import os.path as osp
|
| 4 |
+
import platform
|
| 5 |
+
import re
|
| 6 |
+
import shutil
|
| 7 |
+
import sys
|
| 8 |
+
import tempfile
|
| 9 |
+
import types
|
| 10 |
+
import uuid
|
| 11 |
+
from importlib import import_module
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any, Dict, Iterator, NoReturn, Optional, Union
|
| 14 |
+
import yaml
|
| 15 |
+
|
| 16 |
+
from .misc import import_modules_from_strings
|
| 17 |
+
from .path import check_file_exist
|
| 18 |
+
|
| 19 |
+
BASE_KEY = "_base_"
|
| 20 |
+
DELETE_KEY = "_delete_"
|
| 21 |
+
RESERVED_KEYS = ["filename", "text", "pretty_text"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Config:
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
cfg_dict: Optional[dict] = None,
|
| 28 |
+
cfg_text: Optional[str] = None,
|
| 29 |
+
filename: Optional[str] = None,
|
| 30 |
+
) -> None:
|
| 31 |
+
if cfg_dict is None:
|
| 32 |
+
cfg_dict = dict()
|
| 33 |
+
elif not isinstance(cfg_dict, dict):
|
| 34 |
+
raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
|
| 35 |
+
for key in cfg_dict:
|
| 36 |
+
if key in RESERVED_KEYS:
|
| 37 |
+
raise KeyError(f"{key} is reserved for config file")
|
| 38 |
+
|
| 39 |
+
if isinstance(filename, Path):
|
| 40 |
+
filename = str(filename)
|
| 41 |
+
|
| 42 |
+
super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
|
| 43 |
+
super(Config, self).__setattr__("_filename", filename)
|
| 44 |
+
if cfg_text:
|
| 45 |
+
text = cfg_text
|
| 46 |
+
elif filename:
|
| 47 |
+
with open(filename, "r") as f:
|
| 48 |
+
text = f.read()
|
| 49 |
+
else:
|
| 50 |
+
text = ""
|
| 51 |
+
super(Config, self).__setattr__("_text", text)
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def fromfile(
|
| 55 |
+
filename: str,
|
| 56 |
+
use_predefined_variables: bool = True,
|
| 57 |
+
import_custom_modules: bool = True,
|
| 58 |
+
) -> "Config":
|
| 59 |
+
if isinstance(filename, Path):
|
| 60 |
+
filename = str(filename)
|
| 61 |
+
cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
|
| 62 |
+
if import_custom_modules and cfg_dict.get("custom_imports", None):
|
| 63 |
+
import_modules_from_strings(**cfg_dict["custom_imports"])
|
| 64 |
+
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def _file2dict(filename: str, use_predefined_variables: bool = True) -> tuple[dict, str]:
|
| 68 |
+
filename = osp.abspath(osp.expanduser(filename))
|
| 69 |
+
check_file_exist(filename)
|
| 70 |
+
fileExtname = osp.splitext(filename)[1]
|
| 71 |
+
if fileExtname not in [".py"]:
|
| 72 |
+
raise IOError("Only py type are supported now!")
|
| 73 |
+
|
| 74 |
+
cfg_dict = {}
|
| 75 |
+
|
| 76 |
+
with tempfile.TemporaryDirectory() as temp_config_dir:
|
| 77 |
+
temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=fileExtname)
|
| 78 |
+
if platform.system() == "Windows":
|
| 79 |
+
temp_config_file.close()
|
| 80 |
+
temp_config_name = osp.basename(temp_config_file.name)
|
| 81 |
+
# Substitute predefined variables
|
| 82 |
+
if use_predefined_variables:
|
| 83 |
+
Config._substitute_predefined_vars(filename, temp_config_file.name)
|
| 84 |
+
else:
|
| 85 |
+
shutil.copyfile(filename, temp_config_file.name)
|
| 86 |
+
# Substitute base variables from placeholders to strings
|
| 87 |
+
base_var_dict = Config._pre_substitute_base_vars(temp_config_file.name, temp_config_file.name)
|
| 88 |
+
|
| 89 |
+
if filename.endswith(".py"):
|
| 90 |
+
temp_module_name = osp.splitext(temp_config_name)[0]
|
| 91 |
+
sys.path.insert(0, temp_config_dir)
|
| 92 |
+
Config._validate_py_syntax(filename)
|
| 93 |
+
mod = import_module(temp_module_name)
|
| 94 |
+
sys.path.pop(0)
|
| 95 |
+
cfg_dict = {
|
| 96 |
+
name: value
|
| 97 |
+
for name, value in mod.__dict__.items()
|
| 98 |
+
if not name.startswith("__")
|
| 99 |
+
and not isinstance(value, types.ModuleType)
|
| 100 |
+
and not isinstance(value, types.FunctionType)
|
| 101 |
+
}
|
| 102 |
+
# delete imported module
|
| 103 |
+
del sys.modules[temp_module_name]
|
| 104 |
+
|
| 105 |
+
# close temp file
|
| 106 |
+
temp_config_file.close()
|
| 107 |
+
|
| 108 |
+
cfg_text = filename + "\n"
|
| 109 |
+
with open(filename, "r", encoding="utf-8") as f:
|
| 110 |
+
# Setting encoding explicitly to resolve coding issue on windows
|
| 111 |
+
cfg_text += f.read()
|
| 112 |
+
|
| 113 |
+
if BASE_KEY in cfg_dict:
|
| 114 |
+
cfg_dir = osp.dirname(filename)
|
| 115 |
+
base_filename = cfg_dict.pop(BASE_KEY)
|
| 116 |
+
base_filename = base_filename if isinstance(base_filename, list) else [base_filename]
|
| 117 |
+
|
| 118 |
+
cfg_dict_list = list()
|
| 119 |
+
cfg_text_list = list()
|
| 120 |
+
for f in base_filename:
|
| 121 |
+
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
|
| 122 |
+
cfg_dict_list.append(_cfg_dict)
|
| 123 |
+
cfg_text_list.append(_cfg_text)
|
| 124 |
+
|
| 125 |
+
base_cfg_dict = dict()
|
| 126 |
+
for c in cfg_dict_list:
|
| 127 |
+
duplicate_keys = base_cfg_dict.keys() & c.keys()
|
| 128 |
+
if len(duplicate_keys) > 0:
|
| 129 |
+
raise KeyError("Duplicate key is not allowed among bases. " f"Duplicate keys: {duplicate_keys}")
|
| 130 |
+
base_cfg_dict.update(c)
|
| 131 |
+
|
| 132 |
+
# Substitute base variables from strings to their actual values
|
| 133 |
+
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, base_cfg_dict)
|
| 134 |
+
assert isinstance(cfg_dict, dict)
|
| 135 |
+
|
| 136 |
+
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
|
| 137 |
+
cfg_dict = base_cfg_dict
|
| 138 |
+
|
| 139 |
+
# merge cfg_text
|
| 140 |
+
cfg_text_list.append(cfg_text)
|
| 141 |
+
cfg_text = "\n".join(cfg_text_list)
|
| 142 |
+
|
| 143 |
+
return cfg_dict, cfg_text
|
| 144 |
+
|
| 145 |
+
@staticmethod
|
| 146 |
+
def _validate_py_syntax(filename: str) -> None:
|
| 147 |
+
with open(filename, "r", encoding="utf-8") as f:
|
| 148 |
+
# Setting encoding explicitly to resolve coding issue on windows
|
| 149 |
+
content = f.read()
|
| 150 |
+
try:
|
| 151 |
+
ast.parse(content)
|
| 152 |
+
except SyntaxError as e:
|
| 153 |
+
raise SyntaxError("There are syntax errors in config " f"file {filename}: {e}")
|
| 154 |
+
|
| 155 |
+
@staticmethod
|
| 156 |
+
def _pre_substitute_base_vars(filename: str, temp_config_name: str) -> dict:
|
| 157 |
+
"""Substitute base variable placehoders to string, so that parsing would work."""
|
| 158 |
+
with open(filename, "r", encoding="utf-8") as f:
|
| 159 |
+
config_file = f.read()
|
| 160 |
+
base_var_dict = {}
|
| 161 |
+
regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
|
| 162 |
+
base_vars = set(re.findall(regexp, config_file))
|
| 163 |
+
for base_var in base_vars:
|
| 164 |
+
randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
|
| 165 |
+
base_var_dict[randstr] = base_var
|
| 166 |
+
regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
|
| 167 |
+
config_file = re.sub(regexp, f'"{randstr}"', config_file)
|
| 168 |
+
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
|
| 169 |
+
tmp_config_file.write(config_file)
|
| 170 |
+
return base_var_dict
|
| 171 |
+
|
| 172 |
+
@staticmethod
|
| 173 |
+
def _substitute_base_vars(
|
| 174 |
+
cfg: Union[dict, list, tuple, str],
|
| 175 |
+
base_var_dict: dict,
|
| 176 |
+
base_cfg: dict,
|
| 177 |
+
) -> Union[dict, list, tuple, str]:
|
| 178 |
+
"""Substitute variable strings to their actual values."""
|
| 179 |
+
cfg = copy.deepcopy(cfg)
|
| 180 |
+
|
| 181 |
+
if isinstance(cfg, dict):
|
| 182 |
+
for k, v in cfg.items():
|
| 183 |
+
if isinstance(v, str) and v in base_var_dict:
|
| 184 |
+
new_v = base_cfg
|
| 185 |
+
for new_k in base_var_dict[v].split("."):
|
| 186 |
+
new_v = new_v[new_k]
|
| 187 |
+
cfg[k] = new_v
|
| 188 |
+
elif isinstance(v, (list, tuple, dict)):
|
| 189 |
+
cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
|
| 190 |
+
elif isinstance(cfg, tuple):
|
| 191 |
+
cfg = tuple(Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg)
|
| 192 |
+
elif isinstance(cfg, list):
|
| 193 |
+
cfg = [Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg]
|
| 194 |
+
elif isinstance(cfg, str) and cfg in base_var_dict:
|
| 195 |
+
new_v = base_cfg
|
| 196 |
+
for new_k in base_var_dict[cfg].split("."):
|
| 197 |
+
new_v = new_v[new_k]
|
| 198 |
+
cfg = new_v
|
| 199 |
+
|
| 200 |
+
return cfg
|
| 201 |
+
|
| 202 |
+
@staticmethod
|
| 203 |
+
def _substitute_predefined_vars(filename: str, temp_config_name: str) -> None:
|
| 204 |
+
file_dirname = osp.dirname(filename)
|
| 205 |
+
file_basename = osp.basename(filename)
|
| 206 |
+
file_basename_no_extension = osp.splitext(file_basename)[0]
|
| 207 |
+
file_extname = osp.splitext(filename)[1]
|
| 208 |
+
support_templates = dict(
|
| 209 |
+
fileDirname=file_dirname,
|
| 210 |
+
fileBasename=file_basename,
|
| 211 |
+
fileBasenameNoExtension=file_basename_no_extension,
|
| 212 |
+
fileExtname=file_extname,
|
| 213 |
+
)
|
| 214 |
+
with open(filename, "r", encoding="utf-8") as f:
|
| 215 |
+
config_file = f.read()
|
| 216 |
+
for key, value in support_templates.items():
|
| 217 |
+
regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
|
| 218 |
+
value = value.replace("\\", "/")
|
| 219 |
+
config_file = re.sub(regexp, value, config_file)
|
| 220 |
+
with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
|
| 221 |
+
tmp_config_file.write(config_file)
|
| 222 |
+
|
| 223 |
+
@staticmethod
|
| 224 |
+
def _merge_a_into_b(a: dict, b: dict, allow_list_keys: bool = False) -> dict:
|
| 225 |
+
b = b.copy()
|
| 226 |
+
for k, v in a.items():
|
| 227 |
+
if allow_list_keys and k.isdigit() and isinstance(b, list):
|
| 228 |
+
k = int(k)
|
| 229 |
+
if len(b) <= k:
|
| 230 |
+
raise KeyError(f"Index {k} exceeds the length of list {b}")
|
| 231 |
+
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
|
| 232 |
+
elif isinstance(v, dict):
|
| 233 |
+
if k in b and not v.pop(DELETE_KEY, False):
|
| 234 |
+
allowed_types = (dict, list) if allow_list_keys else dict
|
| 235 |
+
if not isinstance(b[k], allowed_types):
|
| 236 |
+
raise TypeError(
|
| 237 |
+
f"{k}={v} in child config cannot inherit from "
|
| 238 |
+
f"base because {k} is a dict in the child config "
|
| 239 |
+
f"but is of type {type(b[k])} in base config. "
|
| 240 |
+
f"You may set `{DELETE_KEY}=True` to ignore the "
|
| 241 |
+
f"base config."
|
| 242 |
+
)
|
| 243 |
+
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
|
| 244 |
+
else:
|
| 245 |
+
b[k] = ConfigDict(v)
|
| 246 |
+
else:
|
| 247 |
+
b[k] = v
|
| 248 |
+
return b
|
| 249 |
+
|
| 250 |
+
def to_dict(self) -> Any:
|
| 251 |
+
def convert_configdict(obj):
|
| 252 |
+
if isinstance(obj, ConfigDict):
|
| 253 |
+
return {k: convert_configdict(v) for k, v in obj.items()}
|
| 254 |
+
elif isinstance(obj, dict):
|
| 255 |
+
return {k: convert_configdict(v) for k, v in obj.items()}
|
| 256 |
+
elif isinstance(obj, (list, tuple)):
|
| 257 |
+
return [convert_configdict(item) for item in obj]
|
| 258 |
+
else:
|
| 259 |
+
return obj
|
| 260 |
+
|
| 261 |
+
return convert_configdict(self._cfg_dict)
|
| 262 |
+
|
| 263 |
+
@classmethod
|
| 264 |
+
def from_dict(cls, cfg_dict: dict, filename: Optional[str] = None) -> "Config":
|
| 265 |
+
return cls(cfg_dict=cfg_dict, filename=filename)
|
| 266 |
+
|
| 267 |
+
def save_yaml(self, filename: str) -> None:
|
| 268 |
+
with open(filename, "w", encoding="utf-8") as f:
|
| 269 |
+
yaml.safe_dump(self.to_dict(), f, default_flow_style=False, indent=2)
|
| 270 |
+
|
| 271 |
+
@classmethod
|
| 272 |
+
def load_yaml(cls, filename: str) -> "Config":
|
| 273 |
+
with open(filename, "r", encoding="utf-8") as f:
|
| 274 |
+
cfg_dict = yaml.safe_load(f)
|
| 275 |
+
return cls.from_dict(cfg_dict, filename=filename)
|
| 276 |
+
|
| 277 |
+
def __repr__(self) -> str:
|
| 278 |
+
return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"
|
| 279 |
+
|
| 280 |
+
def __len__(self) -> int:
|
| 281 |
+
return len(self._cfg_dict)
|
| 282 |
+
|
| 283 |
+
def __getattr__(self, name: str) -> Any:
|
| 284 |
+
return getattr(self._cfg_dict, name)
|
| 285 |
+
|
| 286 |
+
def __getitem__(self, name: str) -> Any:
|
| 287 |
+
return self._cfg_dict.__getitem__(name)
|
| 288 |
+
|
| 289 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 290 |
+
if isinstance(value, dict):
|
| 291 |
+
value = ConfigDict(value)
|
| 292 |
+
self._cfg_dict.__setattr__(name, value)
|
| 293 |
+
|
| 294 |
+
def __setitem__(self, name: str, value: Any) -> None:
|
| 295 |
+
if isinstance(value, dict):
|
| 296 |
+
value = ConfigDict(value)
|
| 297 |
+
self._cfg_dict.__setitem__(name, value)
|
| 298 |
+
|
| 299 |
+
def __iter__(self) -> Iterator[Any]:
|
| 300 |
+
return iter(self._cfg_dict)
|
| 301 |
+
|
| 302 |
+
def __getstate__(self) -> tuple[dict, str, str]:
|
| 303 |
+
return (self._cfg_dict, self._filename, self._text)
|
| 304 |
+
|
| 305 |
+
def __copy__(self) -> "Config":
|
| 306 |
+
cls = self.__class__
|
| 307 |
+
other = cls.__new__(cls)
|
| 308 |
+
other.__dict__.update(self.__dict__)
|
| 309 |
+
|
| 310 |
+
return other
|
| 311 |
+
|
| 312 |
+
def __deepcopy__(self, memo: dict) -> "Config":
|
| 313 |
+
cls = self.__class__
|
| 314 |
+
other = cls.__new__(cls)
|
| 315 |
+
memo[id(self)] = other
|
| 316 |
+
|
| 317 |
+
for key, value in self.__dict__.items():
|
| 318 |
+
super(Config, other).__setattr__(key, copy.deepcopy(value, memo))
|
| 319 |
+
|
| 320 |
+
return other
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class ConfigDict(Dict):
|
| 324 |
+
def __missing__(self, name: str) -> NoReturn:
|
| 325 |
+
raise KeyError(name)
|
| 326 |
+
|
| 327 |
+
def __getattr__(self, name: str) -> Any:
|
| 328 |
+
try:
|
| 329 |
+
return self[name]
|
| 330 |
+
except KeyError:
|
| 331 |
+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
| 332 |
+
|
| 333 |
+
def to_dict(self) -> Any:
|
| 334 |
+
def convert_configdict(obj):
|
| 335 |
+
if isinstance(obj, ConfigDict):
|
| 336 |
+
return {k: convert_configdict(v) for k, v in obj.items()}
|
| 337 |
+
elif isinstance(obj, dict):
|
| 338 |
+
return {k: convert_configdict(v) for k, v in obj.items()}
|
| 339 |
+
elif isinstance(obj, (list, tuple)):
|
| 340 |
+
return [convert_configdict(item) for item in obj]
|
| 341 |
+
else:
|
| 342 |
+
return obj
|
| 343 |
+
|
| 344 |
+
return convert_configdict(dict(self))
|
hymotion/utils/geometry.py
ADDED
|
@@ -0,0 +1,856 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
|
| 10 |
+
"""
|
| 11 |
+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
| 12 |
+
using Gram--Schmidt orthogonalization per Section B of [1].
|
| 13 |
+
Args:
|
| 14 |
+
d6: 6D rotation representation, of size (*, 6)
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
batch of rotation matrices of size (*, 3, 3)
|
| 18 |
+
|
| 19 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
| 20 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
| 21 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
| 22 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
a1, a2 = d6[..., :3], d6[..., 3:]
|
| 26 |
+
b1 = F.normalize(a1, dim=-1)
|
| 27 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
| 28 |
+
b2 = F.normalize(b2, dim=-1)
|
| 29 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
| 30 |
+
return torch.stack((b1, b2, b3), dim=-2)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def matrix_to_rotation_6d(matrix: Tensor) -> Tensor:
|
| 34 |
+
"""
|
| 35 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
| 36 |
+
by dropping the last row. Note that 6D representation is not unique.
|
| 37 |
+
Args:
|
| 38 |
+
matrix: batch of rotation matrices of size (*, 3, 3)
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
6D rotation representation, of size (*, 6)
|
| 42 |
+
|
| 43 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
| 44 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
| 45 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
| 46 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
| 47 |
+
"""
|
| 48 |
+
batch_dim = matrix.size()[:-2]
|
| 49 |
+
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def standardize_quaternion(quaternions: Tensor) -> Tensor:
|
| 53 |
+
"""
|
| 54 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 55 |
+
part is non negative.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
quaternions: Quaternions with real part first,
|
| 59 |
+
as tensor of shape (..., 4).
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 63 |
+
"""
|
| 64 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _sqrt_positive_part(x: Tensor) -> Tensor:
|
| 68 |
+
"""Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0."""
|
| 69 |
+
ret = torch.zeros_like(x)
|
| 70 |
+
positive_mask = x > 0
|
| 71 |
+
if torch.is_grad_enabled():
|
| 72 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 73 |
+
else:
|
| 74 |
+
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
| 75 |
+
return ret
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def matrix_to_quaternion(matrix: Tensor) -> Tensor:
|
| 79 |
+
"""Convert rotations given as rotation matrices to quaternions.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
| 86 |
+
"""
|
| 87 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 88 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 89 |
+
|
| 90 |
+
batch_dim = matrix.shape[:-2]
|
| 91 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
|
| 92 |
+
|
| 93 |
+
q_abs = _sqrt_positive_part(
|
| 94 |
+
torch.stack(
|
| 95 |
+
[
|
| 96 |
+
1.0 + m00 + m11 + m22,
|
| 97 |
+
1.0 + m00 - m11 - m22,
|
| 98 |
+
1.0 - m00 + m11 - m22,
|
| 99 |
+
1.0 - m00 - m11 + m22,
|
| 100 |
+
],
|
| 101 |
+
dim=-1,
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
| 106 |
+
quat_by_rijk = torch.stack(
|
| 107 |
+
[
|
| 108 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 109 |
+
# `int`.
|
| 110 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 111 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 112 |
+
# `int`.
|
| 113 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 114 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 115 |
+
# `int`.
|
| 116 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 117 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 118 |
+
# `int`.
|
| 119 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 120 |
+
],
|
| 121 |
+
dim=-2,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
| 125 |
+
# the candidate won't be picked.
|
| 126 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 127 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 128 |
+
|
| 129 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
| 130 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
| 131 |
+
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
|
| 132 |
+
return standardize_quaternion(out)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def quaternion_to_axis_angle(quaternions: Tensor) -> Tensor:
|
| 136 |
+
"""Convert rotations given as quaternions to axis/angle.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
quaternions: quaternions with real part first,
|
| 140 |
+
as tensor of shape (..., 4).
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Rotations given as a vector in axis angle form, as a tensor
|
| 144 |
+
of shape (..., 3), where the magnitude is the angle
|
| 145 |
+
turned anticlockwise in radians around the vector's
|
| 146 |
+
direction.
|
| 147 |
+
"""
|
| 148 |
+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
| 149 |
+
half_angles = torch.atan2(norms, quaternions[..., :1])
|
| 150 |
+
angles = 2 * half_angles
|
| 151 |
+
eps = 1e-6
|
| 152 |
+
small_angles = angles.abs() < eps
|
| 153 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
| 154 |
+
sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
| 155 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
| 156 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
| 157 |
+
sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
| 158 |
+
return quaternions[..., 1:] / sin_half_angles_over_angles
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def matrix_to_axis_angle(matrix: Tensor) -> Tensor:
|
| 162 |
+
"""Convert rotations given as rotation matrices to axis/angle.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
Rotations given as a vector in axis angle form, as a tensor
|
| 169 |
+
of shape (..., 3), where the magnitude is the angle
|
| 170 |
+
turned anticlockwise in radians around the vector's
|
| 171 |
+
direction.
|
| 172 |
+
"""
|
| 173 |
+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def quaternion_to_matrix(quaternions: Tensor) -> Tensor:
|
| 177 |
+
"""Convert rotations given as quaternions to rotation matrices.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
quaternions: quaternions with real part first,
|
| 181 |
+
as tensor of shape (..., 4).
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 185 |
+
"""
|
| 186 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
| 187 |
+
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
| 188 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 189 |
+
|
| 190 |
+
o = torch.stack(
|
| 191 |
+
(
|
| 192 |
+
1 - two_s * (j * j + k * k),
|
| 193 |
+
two_s * (i * j - k * r),
|
| 194 |
+
two_s * (i * k + j * r),
|
| 195 |
+
two_s * (i * j + k * r),
|
| 196 |
+
1 - two_s * (i * i + k * k),
|
| 197 |
+
two_s * (j * k - i * r),
|
| 198 |
+
two_s * (i * k - j * r),
|
| 199 |
+
two_s * (j * k + i * r),
|
| 200 |
+
1 - two_s * (i * i + j * j),
|
| 201 |
+
),
|
| 202 |
+
-1,
|
| 203 |
+
)
|
| 204 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def axis_angle_to_quaternion(axis_angle: Tensor) -> Tensor:
|
| 208 |
+
"""Convert rotations given as axis/angle to quaternions.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
| 212 |
+
as a tensor of shape (..., 3), where the magnitude is
|
| 213 |
+
the angle turned anticlockwise in radians around the
|
| 214 |
+
vector's direction.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
| 218 |
+
"""
|
| 219 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
| 220 |
+
half_angles = angles * 0.5
|
| 221 |
+
eps = 1e-6
|
| 222 |
+
small_angles = angles.abs() < eps
|
| 223 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
| 224 |
+
sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
| 225 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
| 226 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
| 227 |
+
sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
| 228 |
+
quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1)
|
| 229 |
+
return quaternions
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def axis_angle_to_matrix(axis_angle: Tensor) -> Tensor:
|
| 233 |
+
"""Convert rotations given as axis/angle to rotation matrices.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
| 237 |
+
as a tensor of shape (..., 3), where the magnitude is
|
| 238 |
+
the angle turned anticlockwise in radians around the
|
| 239 |
+
vector's direction.
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 243 |
+
"""
|
| 244 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def get_T_w2c_from_wcparams(
|
| 248 |
+
global_orient_w: Tensor, transl_w: Tensor, global_orient_c: Tensor, transl_c: Tensor, offset: Tensor
|
| 249 |
+
) -> Tensor:
|
| 250 |
+
"""
|
| 251 |
+
Args:
|
| 252 |
+
global_orient_w: Tensor, (F, 3)
|
| 253 |
+
transl_w: Tensor, (F, 3)
|
| 254 |
+
global_orient_c: Tensor, (F, 3)
|
| 255 |
+
transl_c: Tensor, (F, 3)
|
| 256 |
+
offset: Tensor, (*, 3)
|
| 257 |
+
Returns:
|
| 258 |
+
T_w2c: Tensor, (F, 4, 4)
|
| 259 |
+
"""
|
| 260 |
+
assert global_orient_w.shape == transl_w.shape and len(global_orient_w.shape) == 2
|
| 261 |
+
assert global_orient_c.shape == transl_c.shape and len(global_orient_c.shape) == 2
|
| 262 |
+
|
| 263 |
+
R_w = axis_angle_to_matrix(global_orient_w) # (F, 3, 3)
|
| 264 |
+
t_w = transl_w # (F, 3)
|
| 265 |
+
R_c = axis_angle_to_matrix(global_orient_c) # (F, 3, 3)
|
| 266 |
+
t_c = transl_c # (F, 3)
|
| 267 |
+
|
| 268 |
+
R_w2c = R_c @ R_w.transpose(-1, -2) # (F, 3, 3)
|
| 269 |
+
t_w2c = t_c + offset - torch.einsum("fij,fj->fi", R_w2c, t_w + offset) # (F, 3)
|
| 270 |
+
T_w2c = torch.eye(4, device=global_orient_w.device).repeat(R_w.size(0), 1, 1) # (F, 4, 4)
|
| 271 |
+
T_w2c[..., :3, :3] = R_w2c # (F, 3, 3)
|
| 272 |
+
T_w2c[..., :3, 3] = t_w2c # (F, 3)
|
| 273 |
+
return T_w2c
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_R_c2gv(R_w2c, axis_gravity_in_w=[0, 0, -1]):
|
| 277 |
+
"""
|
| 278 |
+
Args:
|
| 279 |
+
R_w2c: (*, 3, 3)
|
| 280 |
+
Returns:
|
| 281 |
+
R_c2gv: (*, 3, 3)
|
| 282 |
+
"""
|
| 283 |
+
if isinstance(axis_gravity_in_w, list):
|
| 284 |
+
axis_gravity_in_w = torch.tensor(axis_gravity_in_w).float() # gravity direction in world coord
|
| 285 |
+
axis_z_in_c = torch.tensor([0, 0, 1]).float()
|
| 286 |
+
|
| 287 |
+
# get gv-coord axes in in c-coord
|
| 288 |
+
axis_y_of_gv = R_w2c @ axis_gravity_in_w # (*, 3)
|
| 289 |
+
axis_x_of_gv = axis_y_of_gv.cross(axis_z_in_c.expand_as(axis_y_of_gv), dim=-1)
|
| 290 |
+
# normalize
|
| 291 |
+
axis_x_of_gv_norm = axis_x_of_gv.norm(dim=-1, keepdim=True)
|
| 292 |
+
axis_x_of_gv = axis_x_of_gv / (axis_x_of_gv_norm + 1e-5)
|
| 293 |
+
axis_x_of_gv[axis_x_of_gv_norm.squeeze(-1) < 1e-5] = torch.tensor([1.0, 0.0, 0.0]) # use cam x-axis as axis_x_of_gv
|
| 294 |
+
axis_z_of_gv = axis_x_of_gv.cross(axis_y_of_gv, dim=-1)
|
| 295 |
+
|
| 296 |
+
R_gv2c = torch.stack([axis_x_of_gv, axis_y_of_gv, axis_z_of_gv], dim=-1) # (*, 3, 3)
|
| 297 |
+
R_c2gv = R_gv2c.transpose(-1, -2) # (*, 3, 3)
|
| 298 |
+
return R_c2gv
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def get_c_rootparam(global_orient: Tensor, transl: Tensor, T_w2c: Tensor, offset: Tensor) -> Tuple[Tensor, Tensor]:
|
| 302 |
+
"""
|
| 303 |
+
Args:
|
| 304 |
+
global_orient: Tensor, (F, 3)
|
| 305 |
+
transl: Tensor, (F, 3)
|
| 306 |
+
T_w2c: Tensor, (*, 4, 4)
|
| 307 |
+
offset: Tensor, (3,)
|
| 308 |
+
Returns:
|
| 309 |
+
R_c: Tensor, (F, 3)
|
| 310 |
+
t_c: Tensor, (F, 3)
|
| 311 |
+
"""
|
| 312 |
+
assert global_orient.shape == transl.shape and len(global_orient.shape) == 2
|
| 313 |
+
R_w = axis_angle_to_matrix(global_orient) # (F, 3, 3)
|
| 314 |
+
t_w = transl # (F, 3)
|
| 315 |
+
|
| 316 |
+
R_w2c = T_w2c[..., :3, :3] # (*, 3, 3)
|
| 317 |
+
t_w2c = T_w2c[..., :3, 3] # (*, 3)
|
| 318 |
+
if len(R_w2c.shape) == 2:
|
| 319 |
+
R_w2c = R_w2c[None].expand(R_w.size(0), -1, -1) # (F, 3, 3)
|
| 320 |
+
t_w2c = t_w2c[None].expand(t_w.size(0), -1)
|
| 321 |
+
|
| 322 |
+
R_c = matrix_to_axis_angle(R_w2c @ R_w) # (F, 3)
|
| 323 |
+
t_c = torch.einsum("fij,fj->fi", R_w2c, t_w + offset) + t_w2c - offset # (F, 3)
|
| 324 |
+
return R_c, t_c
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def compute_cam_angvel(R_w2c, padding_last=True):
|
| 328 |
+
"""
|
| 329 |
+
R_w2c : (F, 3, 3)
|
| 330 |
+
"""
|
| 331 |
+
# R @ R0 = R1, so R = R1 @ R0^T
|
| 332 |
+
cam_angvel = matrix_to_rotation_6d(R_w2c[1:] @ R_w2c[:-1].transpose(-1, -2)) # (F-1, 6)
|
| 333 |
+
# cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]])) * FPS
|
| 334 |
+
assert padding_last
|
| 335 |
+
cam_angvel = torch.cat([cam_angvel, cam_angvel[-1:]], dim=0) # (F, 6)
|
| 336 |
+
return cam_angvel.float()
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def rot6d_to_rotation_matrix(rot6d):
|
| 340 |
+
"""Convert 6D rotation representation to 3x3 rotation matrix.
|
| 341 |
+
|
| 342 |
+
Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
|
| 343 |
+
Args:
|
| 344 |
+
rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations.
|
| 345 |
+
Returns:
|
| 346 |
+
rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
|
| 347 |
+
"""
|
| 348 |
+
# x = rot6d.view(-1, 3, 2)
|
| 349 |
+
x = rot6d.view(*rot6d.shape[:-1], 3, 2)
|
| 350 |
+
a1 = x[..., 0]
|
| 351 |
+
a2 = x[..., 1]
|
| 352 |
+
b1 = F.normalize(a1, dim=-1)
|
| 353 |
+
b2 = F.normalize(a2 - torch.einsum("...i,...i->...", b1, a2).unsqueeze(-1) * b1, dim=-1)
|
| 354 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
| 355 |
+
return torch.stack((b1, b2, b3), dim=-1)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def rotation_matrix_to_rot6d(rotation_matrix):
|
| 359 |
+
"""Convert 3x3 rotation matrix to 6D rotation representation.
|
| 360 |
+
|
| 361 |
+
Args:
|
| 362 |
+
rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices.
|
| 363 |
+
Returns:
|
| 364 |
+
rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations.
|
| 365 |
+
"""
|
| 366 |
+
v1 = rotation_matrix[..., 0:1]
|
| 367 |
+
v2 = rotation_matrix[..., 1:2]
|
| 368 |
+
rot6d = torch.cat([v1, v2], dim=-1).reshape(*v1.shape[:-2], 6)
|
| 369 |
+
return rot6d
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def quaternion_to_rotation_matrix(quaternion):
|
| 373 |
+
"""Convert quaternion coefficients to rotation matrix.
|
| 374 |
+
|
| 375 |
+
Args:
|
| 376 |
+
quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation.
|
| 377 |
+
Returns:
|
| 378 |
+
rotation matrix corresponding to the quaternion, torch tensor of shape (batch_size, 3, 3)
|
| 379 |
+
"""
|
| 380 |
+
|
| 381 |
+
norm_quaternion = quaternion
|
| 382 |
+
norm_quaternion = norm_quaternion / norm_quaternion.norm(p=2, dim=-1, keepdim=True)
|
| 383 |
+
w, x, y, z = norm_quaternion[..., 0], norm_quaternion[..., 1], norm_quaternion[..., 2], norm_quaternion[..., 3]
|
| 384 |
+
|
| 385 |
+
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
| 386 |
+
wx, wy, wz = w * x, w * y, w * z
|
| 387 |
+
xy, xz, yz = x * y, x * z, y * z
|
| 388 |
+
|
| 389 |
+
rotation_matrix = torch.stack(
|
| 390 |
+
[
|
| 391 |
+
w2 + x2 - y2 - z2,
|
| 392 |
+
2 * xy - 2 * wz,
|
| 393 |
+
2 * wy + 2 * xz,
|
| 394 |
+
2 * wz + 2 * xy,
|
| 395 |
+
w2 - x2 + y2 - z2,
|
| 396 |
+
2 * yz - 2 * wx,
|
| 397 |
+
2 * xz - 2 * wy,
|
| 398 |
+
2 * wx + 2 * yz,
|
| 399 |
+
w2 - x2 - y2 + z2,
|
| 400 |
+
],
|
| 401 |
+
dim=-1,
|
| 402 |
+
)
|
| 403 |
+
rotation_matrix = rotation_matrix.view(*quaternion.shape[:-1], 3, 3)
|
| 404 |
+
return rotation_matrix
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def quaternion_to_angle_axis(quaternion: Tensor) -> Tensor:
|
| 408 |
+
"""
|
| 409 |
+
This function is borrowed from https://github.com/kornia/kornia
|
| 410 |
+
|
| 411 |
+
Convert quaternion vector to angle axis of rotation.
|
| 412 |
+
|
| 413 |
+
Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
quaternion (Tensor): tensor with quaternions.
|
| 417 |
+
|
| 418 |
+
Return:
|
| 419 |
+
Tensor: tensor with angle axis of rotation.
|
| 420 |
+
|
| 421 |
+
Shape:
|
| 422 |
+
- Input: :math:`(*, 4)` where `*` means, any number of dimensions
|
| 423 |
+
- Output: :math:`(*, 3)`
|
| 424 |
+
|
| 425 |
+
Example:
|
| 426 |
+
>>> quaternion = torch.rand(2, 4) # Nx4
|
| 427 |
+
>>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
|
| 428 |
+
"""
|
| 429 |
+
if not torch.is_tensor(quaternion):
|
| 430 |
+
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))
|
| 431 |
+
|
| 432 |
+
if not quaternion.shape[-1] == 4:
|
| 433 |
+
raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape))
|
| 434 |
+
# unpack input and compute conversion
|
| 435 |
+
q1: torch.Tensor = quaternion[..., 1]
|
| 436 |
+
q2: torch.Tensor = quaternion[..., 2]
|
| 437 |
+
q3: torch.Tensor = quaternion[..., 3]
|
| 438 |
+
sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
|
| 439 |
+
|
| 440 |
+
sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
|
| 441 |
+
cos_theta: torch.Tensor = quaternion[..., 0]
|
| 442 |
+
two_theta: torch.Tensor = 2.0 * torch.where(
|
| 443 |
+
cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta)
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
k_pos: torch.Tensor = two_theta / sin_theta
|
| 447 |
+
k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
|
| 448 |
+
k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
|
| 449 |
+
|
| 450 |
+
angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
|
| 451 |
+
angle_axis[..., 0] += q1 * k
|
| 452 |
+
angle_axis[..., 1] += q2 * k
|
| 453 |
+
angle_axis[..., 2] += q3 * k
|
| 454 |
+
return angle_axis
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
|
| 458 |
+
"""
|
| 459 |
+
This function is borrowed from https://github.com/kornia/kornia
|
| 460 |
+
|
| 461 |
+
Convert 3x4 rotation matrix to 4d quaternion vector
|
| 462 |
+
|
| 463 |
+
This algorithm is based on algorithm described in
|
| 464 |
+
https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
rotation_matrix (Tensor): the rotation matrix to convert.
|
| 468 |
+
|
| 469 |
+
Return:
|
| 470 |
+
Tensor: the rotation in quaternion
|
| 471 |
+
|
| 472 |
+
Shape:
|
| 473 |
+
- Input: :math:`(N, 3, 4)`
|
| 474 |
+
- Output: :math:`(N, 4)`
|
| 475 |
+
|
| 476 |
+
Example:
|
| 477 |
+
>>> input = torch.rand(4, 3, 4) # Nx3x4
|
| 478 |
+
>>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
|
| 479 |
+
"""
|
| 480 |
+
if not torch.is_tensor(rotation_matrix):
|
| 481 |
+
raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))
|
| 482 |
+
|
| 483 |
+
if len(rotation_matrix.shape) > 3:
|
| 484 |
+
raise ValueError("Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape))
|
| 485 |
+
if not rotation_matrix.shape[-2:] == (3, 4):
|
| 486 |
+
hom = (
|
| 487 |
+
torch.tensor([0, 0, 1], dtype=rotation_matrix.dtype, device=rotation_matrix.device)
|
| 488 |
+
.reshape(1, 3, 1)
|
| 489 |
+
.expand(rotation_matrix.shape[0], -1, -1)
|
| 490 |
+
)
|
| 491 |
+
rotation_matrix = torch.cat([rotation_matrix, hom], dim=-1)
|
| 492 |
+
|
| 493 |
+
rmat_t = torch.transpose(rotation_matrix, 1, 2)
|
| 494 |
+
|
| 495 |
+
mask_d2 = rmat_t[:, 2, 2] < eps
|
| 496 |
+
|
| 497 |
+
mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
|
| 498 |
+
mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
|
| 499 |
+
|
| 500 |
+
t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
| 501 |
+
q0 = torch.stack(
|
| 502 |
+
[rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2]],
|
| 503 |
+
-1,
|
| 504 |
+
)
|
| 505 |
+
t0_rep = t0.repeat(4, 1).t()
|
| 506 |
+
|
| 507 |
+
t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
|
| 508 |
+
q1 = torch.stack(
|
| 509 |
+
[rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]],
|
| 510 |
+
-1,
|
| 511 |
+
)
|
| 512 |
+
t1_rep = t1.repeat(4, 1).t()
|
| 513 |
+
|
| 514 |
+
t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
| 515 |
+
q2 = torch.stack(
|
| 516 |
+
[rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2],
|
| 517 |
+
-1,
|
| 518 |
+
)
|
| 519 |
+
t2_rep = t2.repeat(4, 1).t()
|
| 520 |
+
|
| 521 |
+
t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
|
| 522 |
+
q3 = torch.stack(
|
| 523 |
+
[t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0]],
|
| 524 |
+
-1,
|
| 525 |
+
)
|
| 526 |
+
t3_rep = t3.repeat(4, 1).t()
|
| 527 |
+
|
| 528 |
+
mask_c0 = mask_d2 * mask_d0_d1
|
| 529 |
+
mask_c1 = mask_d2 * ~mask_d0_d1
|
| 530 |
+
mask_c2 = ~mask_d2 * mask_d0_nd1
|
| 531 |
+
mask_c3 = ~mask_d2 * ~mask_d0_nd1
|
| 532 |
+
mask_c0 = mask_c0.view(-1, 1).type_as(q0)
|
| 533 |
+
mask_c1 = mask_c1.view(-1, 1).type_as(q1)
|
| 534 |
+
mask_c2 = mask_c2.view(-1, 1).type_as(q2)
|
| 535 |
+
mask_c3 = mask_c3.view(-1, 1).type_as(q3)
|
| 536 |
+
|
| 537 |
+
q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
|
| 538 |
+
q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa # noqa
|
| 539 |
+
q *= 0.5
|
| 540 |
+
return q
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def rotation_matrix_to_angle_axis(rotation_matrix):
|
| 544 |
+
"""
|
| 545 |
+
This function is borrowed from https://github.com/kornia/kornia
|
| 546 |
+
|
| 547 |
+
Convert 3x4 rotation matrix to Rodrigues vector
|
| 548 |
+
|
| 549 |
+
Args:
|
| 550 |
+
rotation_matrix (Tensor): rotation matrix.
|
| 551 |
+
|
| 552 |
+
Returns:
|
| 553 |
+
Tensor: Rodrigues vector transformation.
|
| 554 |
+
|
| 555 |
+
Shape:
|
| 556 |
+
- Input: :math:`(N, 3, 4)`
|
| 557 |
+
- Output: :math:`(N, 3)`
|
| 558 |
+
|
| 559 |
+
Example:
|
| 560 |
+
>>> input = torch.rand(2, 3, 4) # Nx4x4
|
| 561 |
+
>>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
|
| 562 |
+
"""
|
| 563 |
+
origin_shape = rotation_matrix.shape[:-2]
|
| 564 |
+
flat_rot = rotation_matrix.reshape(-1, *rotation_matrix.shape[-2:])
|
| 565 |
+
if flat_rot.shape[1:] == (3, 3):
|
| 566 |
+
rot_mat = flat_rot
|
| 567 |
+
hom = (
|
| 568 |
+
torch.tensor([0, 0, 1], dtype=rotation_matrix.dtype, device=rotation_matrix.device)
|
| 569 |
+
.reshape(1, 3, 1)
|
| 570 |
+
.expand(rot_mat.shape[0], -1, -1)
|
| 571 |
+
)
|
| 572 |
+
flat_rot = torch.cat([rot_mat, hom], dim=-1)
|
| 573 |
+
|
| 574 |
+
quaternion = rotation_matrix_to_quaternion(flat_rot)
|
| 575 |
+
aa = quaternion_to_angle_axis(quaternion)
|
| 576 |
+
aa[torch.isnan(aa)] = 0.0
|
| 577 |
+
aa = aa.reshape(*origin_shape, 3)
|
| 578 |
+
return aa
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def quat_to_rotmat(quat):
|
| 582 |
+
"""Convert quaternion coefficients to rotation matrix.
|
| 583 |
+
|
| 584 |
+
Args:
|
| 585 |
+
quat: size = [B, 4] 4 <===>(w, x, y, z)
|
| 586 |
+
Returns:
|
| 587 |
+
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
| 588 |
+
"""
|
| 589 |
+
norm_quat = quat
|
| 590 |
+
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
|
| 591 |
+
w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
|
| 592 |
+
|
| 593 |
+
B = quat.size(0)
|
| 594 |
+
|
| 595 |
+
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
|
| 596 |
+
wx, wy, wz = w * x, w * y, w * z
|
| 597 |
+
xy, xz, yz = x * y, x * z, y * z
|
| 598 |
+
|
| 599 |
+
rotMat = torch.stack(
|
| 600 |
+
[
|
| 601 |
+
w2 + x2 - y2 - z2,
|
| 602 |
+
2 * xy - 2 * wz,
|
| 603 |
+
2 * wy + 2 * xz,
|
| 604 |
+
2 * wz + 2 * xy,
|
| 605 |
+
w2 - x2 + y2 - z2,
|
| 606 |
+
2 * yz - 2 * wx,
|
| 607 |
+
2 * xz - 2 * wy,
|
| 608 |
+
2 * wx + 2 * yz,
|
| 609 |
+
w2 - x2 - y2 + z2,
|
| 610 |
+
],
|
| 611 |
+
dim=1,
|
| 612 |
+
).view(B, 3, 3)
|
| 613 |
+
return rotMat
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def angle_axis_to_rotation_matrix(theta):
|
| 617 |
+
"""Convert axis-angle representation to rotation matrix.
|
| 618 |
+
|
| 619 |
+
Args:
|
| 620 |
+
theta: size = [B, 3]
|
| 621 |
+
Returns:
|
| 622 |
+
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
|
| 623 |
+
"""
|
| 624 |
+
origin_shape = theta.shape[:-1]
|
| 625 |
+
flat_theta = theta.reshape(-1, 3)
|
| 626 |
+
l1norm = torch.norm(flat_theta + 1e-8, p=2, dim=1)
|
| 627 |
+
angle = torch.unsqueeze(l1norm, -1)
|
| 628 |
+
normalized = torch.div(flat_theta, angle)
|
| 629 |
+
angle = angle * 0.5
|
| 630 |
+
v_cos = torch.cos(angle)
|
| 631 |
+
v_sin = torch.sin(angle)
|
| 632 |
+
quat = torch.cat([v_cos, v_sin * normalized], dim=1)
|
| 633 |
+
rot_mat = quat_to_rotmat(quat)
|
| 634 |
+
return rot_mat.reshape(*origin_shape, 3, 3)
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def rotation_matrix_to_euler_angles(rotation_matrix):
|
| 638 |
+
"""Convert 3x3 rotation matrix to Euler angles."""
|
| 639 |
+
is_torch = False
|
| 640 |
+
if isinstance(rotation_matrix, Tensor):
|
| 641 |
+
is_torch = True
|
| 642 |
+
device = rotation_matrix.device
|
| 643 |
+
rotation_matrix = rotation_matrix.cpu().numpy()
|
| 644 |
+
from scipy.spatial.transform import Rotation
|
| 645 |
+
|
| 646 |
+
rot_flat = rotation_matrix.reshape(-1, 3, 3)
|
| 647 |
+
euler_angles = Rotation.from_matrix(rot_flat).as_euler("xyz", degrees=True)
|
| 648 |
+
if is_torch:
|
| 649 |
+
return torch.from_numpy(euler_angles).to(device)
|
| 650 |
+
return euler_angles
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def euler_angles_to_rotation_matrix(euler_angles, degrees=True):
|
| 654 |
+
"""Convert Euler angles to 3x3 rotation matrix.
|
| 655 |
+
|
| 656 |
+
Args:
|
| 657 |
+
euler_angles: Euler angles in xyz order, shape = [B, 3] or any shape with last dimension 3
|
| 658 |
+
degrees: Whether the angles are in degrees (True) or radians (False)
|
| 659 |
+
|
| 660 |
+
Returns:
|
| 661 |
+
Rotation matrix corresponding to the Euler angles, shape = [..., 3, 3]
|
| 662 |
+
"""
|
| 663 |
+
from scipy.spatial.transform import Rotation
|
| 664 |
+
|
| 665 |
+
orig_shape = euler_angles.shape[:-1]
|
| 666 |
+
euler_flat = euler_angles.reshape(-1, 3)
|
| 667 |
+
rot_flat = Rotation.from_euler("xyz", euler_flat, degrees=degrees).as_matrix()
|
| 668 |
+
return rot_flat.reshape(*orig_shape, 3, 3)
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def get_local_transl_vel(transl, global_orient_R, fps=30):
|
| 672 |
+
"""
|
| 673 |
+
transl velocity is in local coordinate (or, SMPL-coord)
|
| 674 |
+
Args:
|
| 675 |
+
transl: (*, L, 3)
|
| 676 |
+
global_orient: (*, L, 3, 3)
|
| 677 |
+
Returns:
|
| 678 |
+
transl_vel: (*, L, 3)
|
| 679 |
+
"""
|
| 680 |
+
transl_vel = transl[..., 1:, :] - transl[..., :-1, :] # (B, L-1, 3)
|
| 681 |
+
transl_vel = torch.cat([torch.zeros_like(transl_vel[:1]), transl_vel], dim=-2) # (B, L, 3) last-padding
|
| 682 |
+
transl_vel = transl_vel * fps
|
| 683 |
+
|
| 684 |
+
# v_local = R^T @ v_global
|
| 685 |
+
local_transl_vel = torch.einsum("...lij,...li->...lj", global_orient_R, transl_vel)
|
| 686 |
+
return local_transl_vel
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def compute_transl_full_cam(pred_cam, bbx_xys, K_fullimg):
|
| 690 |
+
s, tx, ty = pred_cam[..., 0], pred_cam[..., 1], pred_cam[..., 2]
|
| 691 |
+
focal_length = K_fullimg[..., 0, 0]
|
| 692 |
+
|
| 693 |
+
icx = K_fullimg[..., 0, 2]
|
| 694 |
+
icy = K_fullimg[..., 1, 2]
|
| 695 |
+
sb = s * bbx_xys[..., 2]
|
| 696 |
+
cx = 2 * (bbx_xys[..., 0] - icx) / (sb + 1e-9)
|
| 697 |
+
cy = 2 * (bbx_xys[..., 1] - icy) / (sb + 1e-9)
|
| 698 |
+
tz = 2 * focal_length / (sb + 1e-9)
|
| 699 |
+
|
| 700 |
+
cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1)
|
| 701 |
+
return cam_t
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
def quaternion_fix_continuity(q: Tensor) -> Tensor:
|
| 705 |
+
"""Force quaternion continuity across the time dimension by selecting the representation (q or -q) with minimal
|
| 706 |
+
distance (or, equivalently, maximal dot product) between two consecutive frames."""
|
| 707 |
+
assert q.ndim in (
|
| 708 |
+
2,
|
| 709 |
+
3,
|
| 710 |
+
), f"Expected 3D tensor (L, J, 4), or 2D tensor (L, 4), but got shape {q.shape}"
|
| 711 |
+
assert q.shape[-1] == 4, f"Last dimension should be 4 for quaternions, got {q.shape[-1]}"
|
| 712 |
+
if q.shape[0] <= 1:
|
| 713 |
+
return q.clone() # single frame or empty sequence, no need to process
|
| 714 |
+
|
| 715 |
+
result = q.clone()
|
| 716 |
+
# compute the dot product between consecutive frames (L-1, J) or (L-1)
|
| 717 |
+
dot_products = torch.sum(q[1:] * q[:-1], dim=-1)
|
| 718 |
+
# find the negative dot product (indicates need to flip sign)
|
| 719 |
+
flip_mask = dot_products < 0
|
| 720 |
+
# accumulate the flip mask, ensure consistency
|
| 721 |
+
# if a frame needs to be flipped, all subsequent frames need to be flipped the same number of times
|
| 722 |
+
flip_mask = (torch.cumsum(flip_mask.int(), dim=0) % 2).bool()
|
| 723 |
+
# flip the sign of the frames that need to be flipped
|
| 724 |
+
result[1:][flip_mask] *= -1
|
| 725 |
+
return result
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def rot_mat2trans_mat(rot_mat: np.ndarray) -> np.ndarray:
|
| 729 |
+
# assert rot_mat.shape == (3, 3)
|
| 730 |
+
trans_mat = np.identity(4)
|
| 731 |
+
if len(rot_mat.shape) == 2:
|
| 732 |
+
trans_mat = trans_mat
|
| 733 |
+
elif len(rot_mat.shape) == 3:
|
| 734 |
+
trans_mat = np.tile(trans_mat, [rot_mat.shape[0], 1, 1])
|
| 735 |
+
elif len(rot_mat.shape) == 4:
|
| 736 |
+
trans_mat = np.tile(trans_mat, [rot_mat.shape[0], rot_mat.shape[1], 1, 1])
|
| 737 |
+
else:
|
| 738 |
+
raise NotImplementedError
|
| 739 |
+
trans_mat[..., :3, :3] = rot_mat
|
| 740 |
+
return trans_mat
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
def trans2trans_mat(trans: np.ndarray) -> np.ndarray:
|
| 744 |
+
assert trans.shape[-1] == 3
|
| 745 |
+
assert (len(trans.shape) == 1) or (len(trans.shape) == 2) or (len(trans.shape) == 3), trans.shape
|
| 746 |
+
if len(trans.shape) == 1:
|
| 747 |
+
trans_mat = np.identity(4)
|
| 748 |
+
trans_mat[:3, 3] = trans
|
| 749 |
+
elif len(trans.shape) == 2:
|
| 750 |
+
trans_mat = np.tile(np.identity(4), [trans.shape[0], 1, 1])
|
| 751 |
+
trans_mat[:, :3, 3] = trans
|
| 752 |
+
elif len(trans.shape) == 3:
|
| 753 |
+
trans_mat = np.tile(np.identity(4), [trans.shape[0], trans.shape[1], 1, 1])
|
| 754 |
+
trans_mat[:, :, :3, 3] = trans
|
| 755 |
+
else:
|
| 756 |
+
raise NotImplementedError
|
| 757 |
+
return trans_mat
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
def gaussian_kernel1d(sigma: float, order: int, radius: int) -> np.ndarray:
|
| 761 |
+
"""Computes a 1D Gaussian convolution kernel.
|
| 762 |
+
|
| 763 |
+
(from scipy)
|
| 764 |
+
"""
|
| 765 |
+
if order < 0:
|
| 766 |
+
raise ValueError("order must be non-negative")
|
| 767 |
+
exponent_range = np.arange(order + 1)
|
| 768 |
+
sigma2 = sigma * sigma
|
| 769 |
+
x = np.arange(-radius, radius + 1)
|
| 770 |
+
phi_x = np.exp(-0.5 / sigma2 * x**2)
|
| 771 |
+
phi_x = phi_x / phi_x.sum()
|
| 772 |
+
|
| 773 |
+
if order == 0:
|
| 774 |
+
return phi_x
|
| 775 |
+
else:
|
| 776 |
+
# f(x) = q(x) * phi(x) = q(x) * exp(p(x))
|
| 777 |
+
# f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
|
| 778 |
+
# p'(x) = -1 / sigma ** 2
|
| 779 |
+
# Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
|
| 780 |
+
# coefficients of q(x)
|
| 781 |
+
q = np.zeros(order + 1)
|
| 782 |
+
q[0] = 1
|
| 783 |
+
D = np.diag(exponent_range[1:], 1) # D @ q(x) = q'(x)
|
| 784 |
+
P = np.diag(np.ones(order) / -sigma2, -1) # P @ q(x) = q(x) * p'(x)
|
| 785 |
+
Q_deriv = D + P
|
| 786 |
+
for _ in range(order):
|
| 787 |
+
q = Q_deriv.dot(q)
|
| 788 |
+
q = (x[:, None] ** exponent_range).dot(q)
|
| 789 |
+
return q * phi_x
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
def slice_seq_with_padding(whole_seq: np.ndarray, middle_idx: int, length: int) -> np.ndarray:
|
| 793 |
+
whole_seq_padded = whole_seq.copy()
|
| 794 |
+
if middle_idx - length // 2 < 0:
|
| 795 |
+
# need padding
|
| 796 |
+
l_pad_len = length // 2 - middle_idx
|
| 797 |
+
whole_seq_padded = np.concatenate([np.stack([whole_seq_padded[0]] * l_pad_len), whole_seq_padded], axis=0)
|
| 798 |
+
else:
|
| 799 |
+
l_pad_len = 0
|
| 800 |
+
if middle_idx + length - length // 2 > len(whole_seq):
|
| 801 |
+
r_pad_len = middle_idx + length - length // 2 - len(whole_seq)
|
| 802 |
+
whole_seq_padded = np.concatenate([whole_seq_padded, np.stack([whole_seq_padded[-1]] * r_pad_len)], axis=0)
|
| 803 |
+
else:
|
| 804 |
+
r_pad_len = 0
|
| 805 |
+
assert len(whole_seq_padded) == len(whole_seq) + l_pad_len + r_pad_len
|
| 806 |
+
middle_idx_padded = middle_idx + l_pad_len
|
| 807 |
+
assert middle_idx_padded - length // 2 >= 0
|
| 808 |
+
assert middle_idx_padded + length - length // 2 <= len(whole_seq_padded)
|
| 809 |
+
return whole_seq_padded[middle_idx_padded - length // 2 : middle_idx_padded - length // 2 + length]
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
def wavg_quaternion_markley(Q: np.ndarray, weights: np.ndarray) -> np.ndarray:
|
| 813 |
+
"""
|
| 814 |
+
Averaging Quaternions.
|
| 815 |
+
This is a python implementation of Tolga Birdal's algorithm by https://stackoverflow.com/a/49690919
|
| 816 |
+
|
| 817 |
+
Arguments:
|
| 818 |
+
Q(ndarray): an Mx4 ndarray of quaternions.
|
| 819 |
+
weights(list): an M elements list, a weight for each quaternion.
|
| 820 |
+
|
| 821 |
+
refer to Tolga Birdal's matlab implementation on
|
| 822 |
+
https://ww2.mathworks.cn/matlabcentral/fileexchange/40098-tolgabirdal-averaging_quaternions?s_tid=prof_contriblnk&s_tid=mwa_osa_a
|
| 823 |
+
by Tolga Birdal
|
| 824 |
+
Q is an Mx4 matrix of quaternions. weights is an Mx1 vector, a weight for
|
| 825 |
+
each quaternion.
|
| 826 |
+
Qavg is the weighted average quaternion
|
| 827 |
+
This function is especially useful for example when clustering poses
|
| 828 |
+
after a matching process. In such cases a form of weighting per rotation
|
| 829 |
+
is available (e.g. number of votes), which can guide the trust towards a
|
| 830 |
+
specific pose. weights might then be interpreted as the vector of votes
|
| 831 |
+
per pose.
|
| 832 |
+
Markley, F. Landis, Yang Cheng, John Lucas Crassidis, and Yaakov Oshman.
|
| 833 |
+
"Averaging quaternions." Journal of Guidance, Control, and Dynamics 30,
|
| 834 |
+
no. 4 (2007): 1193-1197.
|
| 835 |
+
"""
|
| 836 |
+
|
| 837 |
+
# Form the symmetric accumulator matrix
|
| 838 |
+
# pdb.set_trace()
|
| 839 |
+
A = np.zeros((4, 4))
|
| 840 |
+
M = Q.shape[0]
|
| 841 |
+
wSum = 0
|
| 842 |
+
|
| 843 |
+
for i in range(M):
|
| 844 |
+
q = Q[i, :]
|
| 845 |
+
w_i = weights[i]
|
| 846 |
+
if q[0] < 0:
|
| 847 |
+
# handle the antipodal configuration
|
| 848 |
+
q = -q
|
| 849 |
+
A += w_i * (np.outer(q, q)) # rank 1 update
|
| 850 |
+
wSum += w_i
|
| 851 |
+
|
| 852 |
+
# scale
|
| 853 |
+
A /= wSum
|
| 854 |
+
|
| 855 |
+
# Get the eigenvector corresponding to largest eigen value
|
| 856 |
+
return np.linalg.eigh(A)[1][:, -1]
|
hymotion/utils/loaders.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def load_object(module_name, module_args, **extra_args):
|
| 7 |
+
module_args = module_args.copy()
|
| 8 |
+
module_path = ".".join(module_name.split(".")[:-1]).replace("/", ".")
|
| 9 |
+
module = importlib.import_module(module_path)
|
| 10 |
+
name = module_name.split(".")[-1]
|
| 11 |
+
if module_args is None:
|
| 12 |
+
module_args = {}
|
| 13 |
+
module_args.update(extra_args)
|
| 14 |
+
obj = getattr(module, name)(**module_args)
|
| 15 |
+
return obj
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_module(module_name):
|
| 19 |
+
module_path = module_name.split(".")[0].replace("/", ".")
|
| 20 |
+
module = importlib.import_module(module_path)
|
| 21 |
+
name = module_name.split(".")[-1]
|
| 22 |
+
obj = getattr(module, name)
|
| 23 |
+
return obj
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def check_cfg(cfg, global_dict, verbose=True):
|
| 27 |
+
for key, val in cfg.items():
|
| 28 |
+
if isinstance(val, dict):
|
| 29 |
+
check_cfg(val, global_dict, verbose)
|
| 30 |
+
elif isinstance(val, str):
|
| 31 |
+
if val.startswith("$"):
|
| 32 |
+
if verbose:
|
| 33 |
+
print(f" - Update {key} with {val} = {global_dict[val[1:]]}")
|
| 34 |
+
cfg[key] = global_dict[val[1:]]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def read_yaml(yamlname):
|
| 38 |
+
import yaml
|
| 39 |
+
|
| 40 |
+
with open(yamlname, "r", encoding="utf-8") as file:
|
| 41 |
+
try:
|
| 42 |
+
data = yaml.safe_load(file)
|
| 43 |
+
except yaml.constructor.ConstructorError:
|
| 44 |
+
file.seek(0)
|
| 45 |
+
data = yaml.load(file, Loader=yaml.FullLoader)
|
| 46 |
+
if hasattr(data, "to_dict"):
|
| 47 |
+
data = data.to_dict()
|
| 48 |
+
elif hasattr(data, "_cfg_dict"):
|
| 49 |
+
data = dict(data._cfg_dict)
|
| 50 |
+
|
| 51 |
+
return data
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def write_yaml(data, yamlname):
|
| 55 |
+
import yaml
|
| 56 |
+
|
| 57 |
+
with open(yamlname, "w", encoding="utf-8") as file:
|
| 58 |
+
yaml.dump(data, file)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def check_input(data, verbose=True):
|
| 62 |
+
data_parent = {}
|
| 63 |
+
if "input" in data:
|
| 64 |
+
if verbose:
|
| 65 |
+
print(" - Check input file list")
|
| 66 |
+
for filename in data.pop("input"):
|
| 67 |
+
cfg_new = read_yaml(filename)
|
| 68 |
+
data_parent.update(cfg_new)
|
| 69 |
+
return data_parent
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def merge_dict(dict_A, dict_B, key, verbose=True):
|
| 73 |
+
if isinstance(dict_A[key], dict):
|
| 74 |
+
dict_B = dict_B.copy()
|
| 75 |
+
for key2, val2 in dict_A[key].items():
|
| 76 |
+
if key2 in dict_B[key]:
|
| 77 |
+
merge_dict(dict_A[key], dict_B[key], key2, verbose)
|
| 78 |
+
dict_B[key].pop(key2)
|
| 79 |
+
if len(dict_B[key]) > 0:
|
| 80 |
+
if verbose:
|
| 81 |
+
print(f" - Create {key} with {dict_B[key]}")
|
| 82 |
+
for key2, val2 in dict_B[key].items():
|
| 83 |
+
dict_A[key][key2] = val2
|
| 84 |
+
else:
|
| 85 |
+
if verbose:
|
| 86 |
+
print(f" - Update {key} with {dict_B[key]}")
|
| 87 |
+
dict_A[key] = dict_B[key]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def read_config(cfgname, verbose=True):
|
| 91 |
+
data_base = read_yaml(cfgname)
|
| 92 |
+
data_parent = check_input(data_base, verbose)
|
| 93 |
+
# merge the data_base to data_parent
|
| 94 |
+
for key, val in data_parent.items():
|
| 95 |
+
if key in data_base:
|
| 96 |
+
merge_dict(data_parent, data_base, key, verbose)
|
| 97 |
+
if verbose:
|
| 98 |
+
print(data_parent[key])
|
| 99 |
+
data_base.pop(key)
|
| 100 |
+
data_parent.update(data_base)
|
| 101 |
+
data = data_parent
|
| 102 |
+
check_cfg(data, data, verbose)
|
| 103 |
+
return data
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def update_config(config, args):
|
| 107 |
+
for key, value in vars(args).items():
|
| 108 |
+
if key in config.keys() and value is not None:
|
| 109 |
+
config[key] = value
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def read_yaml_full(path):
|
| 113 |
+
import yaml
|
| 114 |
+
|
| 115 |
+
with open(path, "r") as f:
|
| 116 |
+
return yaml.load(f, Loader=yaml.FullLoader)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def check_ceph_path(path):
|
| 120 |
+
import os
|
| 121 |
+
|
| 122 |
+
if os.path.exists(path):
|
| 123 |
+
return path
|
| 124 |
+
else:
|
| 125 |
+
raise ValueError(f"{path} not found")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def read_json(filename):
|
| 129 |
+
with open(filename, "r", encoding="utf-8") as f:
|
| 130 |
+
return json.load(f)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def write_json(data, filename):
|
| 134 |
+
with open(filename, "w", encoding="utf-8") as f:
|
| 135 |
+
json.dump(data, f, ensure_ascii=False, indent=4)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def load_h5_dataset(filename, ds_name_list=None, parser=None):
|
| 139 |
+
import h5py
|
| 140 |
+
|
| 141 |
+
# ds for dataset
|
| 142 |
+
if "@" in filename:
|
| 143 |
+
filename, start_end = filename.split("@")
|
| 144 |
+
start = int(start_end.split(":")[0])
|
| 145 |
+
end = int(start_end.split(":")[1])
|
| 146 |
+
else:
|
| 147 |
+
start = None
|
| 148 |
+
end = None
|
| 149 |
+
assert os.path.isfile(filename), "cannot find: {}".format(filename)
|
| 150 |
+
|
| 151 |
+
def load_dict(d):
|
| 152 |
+
ds_dict = {}
|
| 153 |
+
for item in d.keys():
|
| 154 |
+
if ds_name_list is not None and item not in ds_name_list:
|
| 155 |
+
continue
|
| 156 |
+
if isinstance(d[item], h5py._hl.dataset.Dataset):
|
| 157 |
+
ds_dict[item] = d[item][()]
|
| 158 |
+
if parser is not None and item in parser:
|
| 159 |
+
ds_dict[item] = parser[item](ds_dict[item])
|
| 160 |
+
elif isinstance(d[item], h5py._hl.group.Group):
|
| 161 |
+
ds_dict[item] = load_dict(d[item])
|
| 162 |
+
for item in d.attrs.keys():
|
| 163 |
+
ds_dict[item] = d.attrs[item]
|
| 164 |
+
return ds_dict
|
| 165 |
+
|
| 166 |
+
with h5py.File(filename, "r") as f:
|
| 167 |
+
ds_dict = load_dict(f)
|
| 168 |
+
for item in f.attrs.keys():
|
| 169 |
+
ds_dict[item] = f.attrs[item]
|
| 170 |
+
if start is not None and end is not None:
|
| 171 |
+
for key in ["LclRotation", "LclTranslation"]:
|
| 172 |
+
ds_dict[key] = ds_dict[key][start:end]
|
| 173 |
+
return ds_dict
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
# hymotion.utils.loaders
|
| 178 |
+
network = load_object("hymotion.utils.base_example.ToyNetwork", {})
|
| 179 |
+
print(network)
|
| 180 |
+
network = load_object("hymotion/utils/base_example.ToyNetwork", {})
|
| 181 |
+
print(network)
|
| 182 |
+
load_object("diffusers.DDIMScheduler", {})
|
| 183 |
+
module = load_object("torch.nn.MSELoss", {"reduction": "none"})
|
| 184 |
+
print(module)
|
hymotion/utils/misc.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from collections.abc import Iterable, Sequence
|
| 3 |
+
from importlib import import_module
|
| 4 |
+
from itertools import repeat
|
| 5 |
+
from os import path as osp
|
| 6 |
+
from typing import Any, Callable, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def is_str(x: Any) -> bool:
|
| 10 |
+
"""Whether the input is an string instance.
|
| 11 |
+
|
| 12 |
+
Note: This method is deprecated since python 2 is no longer supported.
|
| 13 |
+
"""
|
| 14 |
+
return isinstance(x, str)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def is_seq_of(seq: Any, expected_type: Any, seq_type: Any = None) -> bool:
|
| 18 |
+
"""Check whether it is a sequence of some type.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
seq (Sequence): The sequence to be checked.
|
| 22 |
+
expected_type (type): Expected type of sequence items.
|
| 23 |
+
seq_type (type, optional): Expected sequence type.
|
| 24 |
+
Returns:
|
| 25 |
+
bool: Whether the sequence is valid.
|
| 26 |
+
"""
|
| 27 |
+
if seq_type is None:
|
| 28 |
+
exp_seq_type = Sequence
|
| 29 |
+
else:
|
| 30 |
+
assert isinstance(seq_type, type)
|
| 31 |
+
exp_seq_type = seq_type
|
| 32 |
+
if not isinstance(seq, exp_seq_type):
|
| 33 |
+
return False
|
| 34 |
+
for item in seq:
|
| 35 |
+
if not isinstance(item, expected_type):
|
| 36 |
+
return False
|
| 37 |
+
return True
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def is_list_of(seq: Any, expected_type: Any) -> bool:
|
| 41 |
+
"""Check whether it is a list of some type.
|
| 42 |
+
|
| 43 |
+
A partial method of :func:`is_seq_of`.
|
| 44 |
+
"""
|
| 45 |
+
return is_seq_of(seq, expected_type, seq_type=list)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def is_tuple_of(seq: Any, expected_type: Any) -> bool:
|
| 49 |
+
"""Check whether it is a tuple of some type.
|
| 50 |
+
|
| 51 |
+
A partial method of :func:`is_seq_of`.
|
| 52 |
+
"""
|
| 53 |
+
return is_seq_of(seq, expected_type, seq_type=tuple)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def import_modules_from_strings(
|
| 57 |
+
imports: Union[list[str], str], allow_failed_imports: bool = False
|
| 58 |
+
) -> Optional[list[Any]]:
|
| 59 |
+
if not imports:
|
| 60 |
+
return
|
| 61 |
+
single_import = False
|
| 62 |
+
if isinstance(imports, str):
|
| 63 |
+
single_import = True
|
| 64 |
+
imports = [imports]
|
| 65 |
+
if not isinstance(imports, list):
|
| 66 |
+
raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
|
| 67 |
+
imported = []
|
| 68 |
+
for imp in imports:
|
| 69 |
+
if not isinstance(imp, str):
|
| 70 |
+
raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
|
| 71 |
+
try:
|
| 72 |
+
imported_tmp = import_module(imp)
|
| 73 |
+
except ImportError:
|
| 74 |
+
if allow_failed_imports:
|
| 75 |
+
warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
|
| 76 |
+
imported_tmp = None
|
| 77 |
+
else:
|
| 78 |
+
raise ImportError
|
| 79 |
+
imported.append(imported_tmp)
|
| 80 |
+
if single_import:
|
| 81 |
+
imported = imported[0]
|
| 82 |
+
return imported
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _ntuple(n: int) -> Callable:
|
| 86 |
+
def parse(x: Any) -> Tuple:
|
| 87 |
+
if isinstance(x, Iterable) and not isinstance(x, str):
|
| 88 |
+
x = tuple(x)
|
| 89 |
+
if len(x) == 1:
|
| 90 |
+
x = tuple(repeat(x[0], n))
|
| 91 |
+
return x
|
| 92 |
+
return tuple(repeat(x, n))
|
| 93 |
+
|
| 94 |
+
return parse
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
to_1tuple = _ntuple(1)
|
| 98 |
+
to_2tuple = _ntuple(2)
|
| 99 |
+
to_3tuple = _ntuple(3)
|
| 100 |
+
to_4tuple = _ntuple(4)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def seconds_to_hmsms(seconds: float) -> tuple[int, int, int, int]:
|
| 104 |
+
hours, remainder = divmod(seconds, 3600)
|
| 105 |
+
minutes, remainder = divmod(remainder, 60)
|
| 106 |
+
seconds, milliseconds = divmod(remainder, 1)
|
| 107 |
+
milliseconds *= 1000
|
| 108 |
+
return int(hours), int(minutes), int(seconds), int(milliseconds)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def frames_to_hmsms(frames: int, frame_rate: int = 30) -> tuple[int, int, int, int]:
|
| 112 |
+
seconds = frames / frame_rate
|
| 113 |
+
return seconds_to_hmsms(seconds)
|
hymotion/utils/motion_process.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def smooth_quats(quats: np.ndarray, sigma: float = 1.0) -> np.ndarray:
|
| 9 |
+
from .geometry import gaussian_kernel1d, quaternion_fix_continuity, slice_seq_with_padding, wavg_quaternion_markley
|
| 10 |
+
|
| 11 |
+
if len(quats) == 0 or sigma <= 0:
|
| 12 |
+
return quats.copy()
|
| 13 |
+
|
| 14 |
+
q_all = quaternion_fix_continuity(torch.from_numpy(quats)).numpy()
|
| 15 |
+
|
| 16 |
+
results = q_all.copy()
|
| 17 |
+
truncate = 4.0
|
| 18 |
+
order = 0
|
| 19 |
+
lw = int(truncate * float(sigma) + 0.5)
|
| 20 |
+
weights = gaussian_kernel1d(sigma=sigma, order=order, radius=lw)[::-1]
|
| 21 |
+
kernel_len = len(weights)
|
| 22 |
+
|
| 23 |
+
for fr in range(len(q_all)):
|
| 24 |
+
cur_quats = slice_seq_with_padding(q_all, fr, kernel_len) # (K,4)
|
| 25 |
+
ref = cur_quats[kernel_len // 2 : kernel_len // 2 + 1] # (1,4)
|
| 26 |
+
dots = (cur_quats * ref).sum(axis=-1, keepdims=True) # (K,1)
|
| 27 |
+
cur_quats = np.where(dots < 0.0, -cur_quats, cur_quats)
|
| 28 |
+
|
| 29 |
+
results[fr, :] = wavg_quaternion_markley(cur_quats, weights)
|
| 30 |
+
|
| 31 |
+
return results.copy()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def smooth_rotation(
|
| 35 |
+
quats: np.ndarray,
|
| 36 |
+
# joint_names: List[str],
|
| 37 |
+
# smooth_joints: List[str],
|
| 38 |
+
sigma: float = 1.0,
|
| 39 |
+
) -> np.ndarray:
|
| 40 |
+
from .geometry import quaternion_fix_continuity
|
| 41 |
+
|
| 42 |
+
if quats.ndim == 4:
|
| 43 |
+
is_batch = True
|
| 44 |
+
else:
|
| 45 |
+
is_batch = False
|
| 46 |
+
quats = quats[None, ...]
|
| 47 |
+
for b in range(quats.shape[0]):
|
| 48 |
+
for j_idx in range(quats.shape[2]):
|
| 49 |
+
cur_quats = quats[b, :, j_idx].copy()
|
| 50 |
+
cur_quats_t = quaternion_fix_continuity(torch.from_numpy(cur_quats)).numpy()
|
| 51 |
+
quats[b, :, j_idx] = smooth_quats(cur_quats_t, sigma=sigma)
|
| 52 |
+
if not is_batch:
|
| 53 |
+
quats = quats.squeeze(0)
|
| 54 |
+
return quats
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def unwrap_euler_over_time(xyz: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
# xyz: (B, L, J, 3)
|
| 59 |
+
# y[t] = y[0] + cumsum(wrap(Δy))
|
| 60 |
+
y = xyz.clone()
|
| 61 |
+
dy = torch.atan2(torch.sin(y[:, 1:] - y[:, :-1]), torch.cos(y[:, 1:] - y[:, :-1]))
|
| 62 |
+
y[:, 1:] = y[:, :1] + torch.cumsum(dy, dim=1)
|
| 63 |
+
return y
|
hymotion/utils/path.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import platform
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Generator, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
from .misc import is_str
|
| 8 |
+
|
| 9 |
+
if platform.system() == "Windows":
|
| 10 |
+
import regex as re
|
| 11 |
+
else:
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None:
|
| 16 |
+
if not osp.isfile(filename):
|
| 17 |
+
raise FileNotFoundError(msg_tmpl.format(filename))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def mkdir_or_exist(dir_name: str, mode: int = 0o777) -> None:
|
| 21 |
+
if dir_name == "":
|
| 22 |
+
return
|
| 23 |
+
dir_name = osp.expanduser(dir_name)
|
| 24 |
+
os.makedirs(dir_name, mode=mode, exist_ok=True)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def symlink(src: str, dst: str, overwrite: bool = True, **kwargs) -> None:
|
| 28 |
+
if os.path.lexists(dst) and overwrite:
|
| 29 |
+
os.remove(dst)
|
| 30 |
+
os.symlink(src, dst, **kwargs)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def is_filepath(x: Any) -> bool:
|
| 34 |
+
return is_str(x) or isinstance(x, Path)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def scandir(
|
| 38 |
+
dir_path: Union[str, Path],
|
| 39 |
+
suffix: Optional[str] = None,
|
| 40 |
+
recursive: bool = False,
|
| 41 |
+
case_sensitive: bool = True,
|
| 42 |
+
) -> Generator[str, None, None]:
|
| 43 |
+
"""Scan a directory to find the interested files.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
dir_path (str | :obj:`Path`): Path of the directory.
|
| 47 |
+
suffix (str | tuple(str), optional): File suffix that we are
|
| 48 |
+
interested in. Default: None.
|
| 49 |
+
recursive (bool, optional): If set to True, recursively scan the
|
| 50 |
+
directory. Default: False.
|
| 51 |
+
case_sensitive (bool, optional) : If set to False, ignore the case of
|
| 52 |
+
suffix. Default: True.
|
| 53 |
+
Returns:
|
| 54 |
+
A generator for all the interested files with relative paths.
|
| 55 |
+
"""
|
| 56 |
+
if isinstance(dir_path, (str, Path)):
|
| 57 |
+
dir_path = str(dir_path)
|
| 58 |
+
else:
|
| 59 |
+
raise TypeError('"dir_path" must be a string or Path object')
|
| 60 |
+
|
| 61 |
+
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
| 62 |
+
raise TypeError('"suffix" must be a string or tuple of strings')
|
| 63 |
+
|
| 64 |
+
if suffix is not None and not case_sensitive:
|
| 65 |
+
suffix = suffix.lower() if isinstance(suffix, str) else tuple(item.lower() for item in suffix)
|
| 66 |
+
|
| 67 |
+
root = dir_path
|
| 68 |
+
|
| 69 |
+
def _scandir(
|
| 70 |
+
dir_path: Union[str, Path],
|
| 71 |
+
suffix: Optional[str],
|
| 72 |
+
recursive: bool,
|
| 73 |
+
case_sensitive: bool,
|
| 74 |
+
) -> Generator[str, None, None]:
|
| 75 |
+
for entry in os.scandir(dir_path):
|
| 76 |
+
if not entry.name.startswith(".") and entry.is_file():
|
| 77 |
+
rel_path = osp.relpath(entry.path, root)
|
| 78 |
+
_rel_path = rel_path if case_sensitive else rel_path.lower()
|
| 79 |
+
if suffix is None or _rel_path.endswith(suffix):
|
| 80 |
+
yield rel_path
|
| 81 |
+
elif recursive and os.path.isdir(entry.path):
|
| 82 |
+
# scan recursively if entry.path is a directory
|
| 83 |
+
yield from _scandir(entry.path, suffix, recursive, case_sensitive)
|
| 84 |
+
|
| 85 |
+
return _scandir(dir_path, suffix, recursive, case_sensitive)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def find_files(directory, pattern, recursive=True, abspath=False) -> List[str]:
|
| 89 |
+
regex = re.compile(pattern)
|
| 90 |
+
file_list = []
|
| 91 |
+
for root, _, files in os.walk(directory):
|
| 92 |
+
for f in files:
|
| 93 |
+
if regex.match(f) is not None:
|
| 94 |
+
file_list.append(os.path.join(root, f))
|
| 95 |
+
if not recursive:
|
| 96 |
+
break
|
| 97 |
+
map_func = os.path.abspath if abspath else os.path.relpath
|
| 98 |
+
return list(map(map_func, sorted(file_list)))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def natural_keys(text: str, retoken: str = r"[a-zA-Z]*(\d+)[a-zA-Z_]*[\.].*", n: int = 1) -> Union[int, str]:
|
| 102 |
+
def _atoi(text: str) -> Union[int, str]:
|
| 103 |
+
return int(text) if text.isdigit() else text.lower()
|
| 104 |
+
|
| 105 |
+
return _atoi(re.split(retoken, text)[n])
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
listdirs = lambda root: [osp.join(base, d) for base, dirs, _ in os.walk(root) if dirs for d in dirs]
|
| 109 |
+
|
| 110 |
+
listfiles = lambda root: [f for base, _, files in os.walk(root) if files for f in files]
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def parse_dirs_and_sort(
|
| 114 |
+
input_dirs: Union[list, str],
|
| 115 |
+
suffix: str,
|
| 116 |
+
is_sort: bool = False,
|
| 117 |
+
with_prefix: bool = True,
|
| 118 |
+
) -> List[str]:
|
| 119 |
+
if isinstance(input_dirs, list):
|
| 120 |
+
input_dirs_list = []
|
| 121 |
+
for iter_input_dir in input_dirs:
|
| 122 |
+
if osp.isdir(iter_input_dir):
|
| 123 |
+
input_dirs_list += [
|
| 124 |
+
osp.join(iter_input_dir, x) if with_prefix else x
|
| 125 |
+
for x in scandir(
|
| 126 |
+
iter_input_dir,
|
| 127 |
+
suffix=suffix,
|
| 128 |
+
recursive=True,
|
| 129 |
+
case_sensitive=False,
|
| 130 |
+
)
|
| 131 |
+
]
|
| 132 |
+
elif osp.isfile(iter_input_dir):
|
| 133 |
+
if iter_input_dir.endswith(suffix):
|
| 134 |
+
input_dirs_list += [iter_input_dir]
|
| 135 |
+
else:
|
| 136 |
+
raise ValueError(f"Input path {iter_input_dir} is not exist.")
|
| 137 |
+
elif isinstance(input_dirs, str):
|
| 138 |
+
if osp.isdir(input_dirs):
|
| 139 |
+
input_dirs_list = [
|
| 140 |
+
osp.join(input_dirs, x) if with_prefix else x
|
| 141 |
+
for x in scandir(input_dirs, suffix=suffix, recursive=True, case_sensitive=False)
|
| 142 |
+
]
|
| 143 |
+
elif osp.isfile(input_dirs):
|
| 144 |
+
if input_dirs.endswith(suffix):
|
| 145 |
+
input_dirs_list = [input_dirs]
|
| 146 |
+
else:
|
| 147 |
+
input_dirs_list = []
|
| 148 |
+
else:
|
| 149 |
+
raise ValueError(f"Input path {input_dirs} is not exist.")
|
| 150 |
+
else:
|
| 151 |
+
raise ValueError("Only support list or str input.")
|
| 152 |
+
|
| 153 |
+
if is_sort:
|
| 154 |
+
try:
|
| 155 |
+
try:
|
| 156 |
+
input_dirs_list = sorted(
|
| 157 |
+
input_dirs_list,
|
| 158 |
+
key=lambda text: (
|
| 159 |
+
natural_keys(text, retoken=r"[a-zA-Z]*(\d+)_[0-9a-zA-Z_]*[\.].*", n=1),
|
| 160 |
+
natural_keys(text, retoken=r"[0-9a-zA-Z]*_(\d+)[a-zA-Z_]*[\.].*", n=1),
|
| 161 |
+
),
|
| 162 |
+
)
|
| 163 |
+
except:
|
| 164 |
+
input_dirs_list = sorted(input_dirs_list, key=lambda text: (natural_keys(text)))
|
| 165 |
+
except:
|
| 166 |
+
input_dirs_list = sorted(input_dirs_list, key=lambda text: text)
|
| 167 |
+
|
| 168 |
+
return input_dirs_list
|
hymotion/utils/smplh2woodfbx.py
ADDED
|
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import tempfile
|
| 5 |
+
from typing import Dict, Optional
|
| 6 |
+
|
| 7 |
+
import fbx
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from transforms3d.euler import mat2euler
|
| 11 |
+
|
| 12 |
+
from .geometry import angle_axis_to_rotation_matrix, rot6d_to_rotation_matrix, rotation_matrix_to_angle_axis
|
| 13 |
+
|
| 14 |
+
# yapf: disable
|
| 15 |
+
SMPLH_JOINT2NUM = {
|
| 16 |
+
"Pelvis": 0, "L_Hip": 1, "R_Hip": 2, "Spine1": 3,
|
| 17 |
+
"L_Knee": 4, "R_Knee": 5, "Spine2": 6,
|
| 18 |
+
"L_Ankle": 7, "R_Ankle": 8,
|
| 19 |
+
"Spine3": 9,
|
| 20 |
+
"L_Foot": 10, "R_Foot": 11,
|
| 21 |
+
"Neck": 12, "L_Collar": 13, "R_Collar": 14, "Head": 15,
|
| 22 |
+
"L_Shoulder": 16, "R_Shoulder": 17,
|
| 23 |
+
"L_Elbow": 18, "R_Elbow": 19,
|
| 24 |
+
"L_Wrist": 20, "R_Wrist": 21,
|
| 25 |
+
"L_Index1": 22, "L_Index2": 23, "L_Index3": 24,
|
| 26 |
+
"L_Middle1": 25, "L_Middle2": 26, "L_Middle3": 27,
|
| 27 |
+
"L_Pinky1": 28, "L_Pinky2": 29, "L_Pinky3": 30,
|
| 28 |
+
"L_Ring1": 31, "L_Ring2": 32, "L_Ring3": 33,
|
| 29 |
+
"L_Thumb1": 34, "L_Thumb2": 35, "L_Thumb3": 36,
|
| 30 |
+
"R_Index1": 37, "R_Index2": 38, "R_Index3": 39,
|
| 31 |
+
"R_Middle1": 40, "R_Middle2": 41, "R_Middle3": 42,
|
| 32 |
+
"R_Pinky1": 43, "R_Pinky2": 44, "R_Pinky3": 45,
|
| 33 |
+
"R_Ring1": 46, "R_Ring2": 47, "R_Ring3": 48,
|
| 34 |
+
"R_Thumb1": 49, "R_Thumb2": 50, "R_Thumb3": 51,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
# Mapping from SMPL-H joint names to lowercase names used in some FBX templates
|
| 38 |
+
SMPLH_TO_LOWERCASE_MAPPING = {
|
| 39 |
+
"Pelvis": "pelvis",
|
| 40 |
+
"L_Hip": "left_hip",
|
| 41 |
+
"R_Hip": "right_hip",
|
| 42 |
+
"Spine1": "spine1",
|
| 43 |
+
"L_Knee": "left_knee",
|
| 44 |
+
"R_Knee": "right_knee",
|
| 45 |
+
"Spine2": "spine2",
|
| 46 |
+
"L_Ankle": "left_ankle",
|
| 47 |
+
"R_Ankle": "right_ankle",
|
| 48 |
+
"Spine3": "spine3",
|
| 49 |
+
"L_Foot": "left_foot",
|
| 50 |
+
"R_Foot": "right_foot",
|
| 51 |
+
"Neck": "neck",
|
| 52 |
+
"L_Collar": "left_collar",
|
| 53 |
+
"R_Collar": "right_collar",
|
| 54 |
+
"Head": "head",
|
| 55 |
+
"L_Shoulder": "left_shoulder",
|
| 56 |
+
"R_Shoulder": "right_shoulder",
|
| 57 |
+
"L_Elbow": "left_elbow",
|
| 58 |
+
"R_Elbow": "right_elbow",
|
| 59 |
+
"L_Wrist": "left_wrist",
|
| 60 |
+
"R_Wrist": "right_wrist",
|
| 61 |
+
"L_Index1": "left_index1",
|
| 62 |
+
"L_Index2": "left_index2",
|
| 63 |
+
"L_Index3": "left_index3",
|
| 64 |
+
"L_Middle1": "left_middle1",
|
| 65 |
+
"L_Middle2": "left_middle2",
|
| 66 |
+
"L_Middle3": "left_middle3",
|
| 67 |
+
"L_Pinky1": "left_pinky1",
|
| 68 |
+
"L_Pinky2": "left_pinky2",
|
| 69 |
+
"L_Pinky3": "left_pinky3",
|
| 70 |
+
"L_Ring1": "left_ring1",
|
| 71 |
+
"L_Ring2": "left_ring2",
|
| 72 |
+
"L_Ring3": "left_ring3",
|
| 73 |
+
"L_Thumb1": "left_thumb1",
|
| 74 |
+
"L_Thumb2": "left_thumb2",
|
| 75 |
+
"L_Thumb3": "left_thumb3",
|
| 76 |
+
"R_Index1": "right_index1",
|
| 77 |
+
"R_Index2": "right_index2",
|
| 78 |
+
"R_Index3": "right_index3",
|
| 79 |
+
"R_Middle1": "right_middle1",
|
| 80 |
+
"R_Middle2": "right_middle2",
|
| 81 |
+
"R_Middle3": "right_middle3",
|
| 82 |
+
"R_Pinky1": "right_pinky1",
|
| 83 |
+
"R_Pinky2": "right_pinky2",
|
| 84 |
+
"R_Pinky3": "right_pinky3",
|
| 85 |
+
"R_Ring1": "right_ring1",
|
| 86 |
+
"R_Ring2": "right_ring2",
|
| 87 |
+
"R_Ring3": "right_ring3",
|
| 88 |
+
"R_Thumb1": "right_thumb1",
|
| 89 |
+
"R_Thumb2": "right_thumb2",
|
| 90 |
+
"R_Thumb3": "right_thumb3",
|
| 91 |
+
}
|
| 92 |
+
# yapf: enable
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _loadFbxScene(fbxManager, filepath):
|
| 96 |
+
"""Load an FBX file into a scene"""
|
| 97 |
+
importer = fbx.FbxImporter.Create(fbxManager, "")
|
| 98 |
+
|
| 99 |
+
if not importer.Initialize(filepath, -1, fbxManager.GetIOSettings()):
|
| 100 |
+
raise Exception(
|
| 101 |
+
f"Failed to initialize FBX importer for: {filepath}\nError: {importer.GetStatus().GetErrorString()}"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
fbxScene = fbx.FbxScene.Create(fbxManager, "")
|
| 105 |
+
importer.Import(fbxScene)
|
| 106 |
+
importer.Destroy()
|
| 107 |
+
|
| 108 |
+
return fbxScene
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _collectAllNodes(node, nodes_dict=None):
|
| 112 |
+
"""Recursively collect all nodes in the scene hierarchy"""
|
| 113 |
+
if nodes_dict is None:
|
| 114 |
+
nodes_dict = {}
|
| 115 |
+
|
| 116 |
+
nodes_dict[node.GetName()] = node
|
| 117 |
+
|
| 118 |
+
for i in range(node.GetChildCount()):
|
| 119 |
+
_collectAllNodes(node.GetChild(i), nodes_dict)
|
| 120 |
+
|
| 121 |
+
return nodes_dict
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _collectSkeletonNodes(node, skeleton_nodes=None):
|
| 125 |
+
"""Recursively collect skeleton/bone nodes"""
|
| 126 |
+
if skeleton_nodes is None:
|
| 127 |
+
skeleton_nodes = {}
|
| 128 |
+
|
| 129 |
+
# Check if this node has a skeleton attribute
|
| 130 |
+
attr = node.GetNodeAttribute()
|
| 131 |
+
if attr and attr.GetAttributeType() == fbx.FbxNodeAttribute.EType.eSkeleton:
|
| 132 |
+
skeleton_nodes[node.GetName()] = node
|
| 133 |
+
|
| 134 |
+
for i in range(node.GetChildCount()):
|
| 135 |
+
_collectSkeletonNodes(node.GetChild(i), skeleton_nodes)
|
| 136 |
+
|
| 137 |
+
return skeleton_nodes
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _animateSingleChannel(animLayer, component, name, values, frameDuration):
|
| 141 |
+
"""Animate a single channel (X, Y, or Z) with keyframes"""
|
| 142 |
+
ncomp = {"X": 0, "Y": 1, "Z": 2}.get(name, 0)
|
| 143 |
+
|
| 144 |
+
time = fbx.FbxTime()
|
| 145 |
+
curve = component.GetCurve(animLayer, name, True)
|
| 146 |
+
curve.KeyModifyBegin()
|
| 147 |
+
for nth in range(len(values)):
|
| 148 |
+
time.SetSecondDouble(nth * frameDuration)
|
| 149 |
+
keyIndex = curve.KeyAdd(time)[0]
|
| 150 |
+
curve.KeySetValue(keyIndex, values[nth][ncomp])
|
| 151 |
+
curve.KeySetInterpolation(keyIndex, fbx.FbxAnimCurveDef.EInterpolationType.eInterpolationConstant)
|
| 152 |
+
curve.KeyModifyEnd()
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _animateRotationKeyFrames(animLayer, node, rot_matrices, frameDuration):
|
| 156 |
+
"""Animate rotation keyframes for a node using rotation matrices"""
|
| 157 |
+
rotations = []
|
| 158 |
+
for nth in range(len(rot_matrices)):
|
| 159 |
+
# Convert rotation matrix to Euler angles (XYZ order)
|
| 160 |
+
euler = np.rad2deg(mat2euler(rot_matrices[nth], axes="sxyz"))
|
| 161 |
+
rotations.append(euler)
|
| 162 |
+
|
| 163 |
+
_animateSingleChannel(animLayer, node.LclRotation, "X", rotations, frameDuration)
|
| 164 |
+
_animateSingleChannel(animLayer, node.LclRotation, "Y", rotations, frameDuration)
|
| 165 |
+
_animateSingleChannel(animLayer, node.LclRotation, "Z", rotations, frameDuration)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _animateTranslationKeyFrames(animLayer, node, translations, frameDuration):
|
| 169 |
+
"""Animate translation keyframes for a node"""
|
| 170 |
+
# Ensure translations is a numpy array with shape (num_frames, 3)
|
| 171 |
+
if isinstance(translations, torch.Tensor):
|
| 172 |
+
translations = translations.numpy()
|
| 173 |
+
translations = np.asarray(translations, dtype=np.float64)
|
| 174 |
+
|
| 175 |
+
if len(translations.shape) == 1:
|
| 176 |
+
# Single frame, reshape to (1, 3)
|
| 177 |
+
translations = translations.reshape(1, -1)
|
| 178 |
+
|
| 179 |
+
_animateSingleChannel(animLayer, node.LclTranslation, "X", translations, frameDuration)
|
| 180 |
+
_animateSingleChannel(animLayer, node.LclTranslation, "Y", translations, frameDuration)
|
| 181 |
+
_animateSingleChannel(animLayer, node.LclTranslation, "Z", translations, frameDuration)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _clearExistingAnimations(fbxScene):
|
| 185 |
+
"""Remove all existing animation stacks from the scene"""
|
| 186 |
+
anim_stack_count = fbxScene.GetSrcObjectCount(fbx.FbxCriteria.ObjectType(fbx.FbxAnimStack.ClassId))
|
| 187 |
+
for i in range(anim_stack_count - 1, -1, -1):
|
| 188 |
+
anim_stack = fbxScene.GetSrcObject(fbx.FbxCriteria.ObjectType(fbx.FbxAnimStack.ClassId), i)
|
| 189 |
+
if anim_stack:
|
| 190 |
+
anim_stack.Destroy()
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _applyAnimationToSkeleton(fbxScene, nodes_map, rot_matrices, translations, fps, smplh_to_fbx_mapping, name="Take1"):
|
| 194 |
+
"""
|
| 195 |
+
Apply SMPL-H animation data to skeleton nodes in the FBX scene.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
fbxScene: FBX scene object
|
| 199 |
+
nodes_map: Dictionary of node_name -> FbxNode
|
| 200 |
+
rot_matrices: (num_frames, num_joints, 3, 3) rotation matrices
|
| 201 |
+
translations: (num_frames, 3) root translations (relative displacement, not absolute position)
|
| 202 |
+
fps: Frame rate
|
| 203 |
+
smplh_to_fbx_mapping: Mapping from SMPL-H joint names to FBX node names
|
| 204 |
+
name: Animation take name
|
| 205 |
+
"""
|
| 206 |
+
frameDuration = 1.0 / fps
|
| 207 |
+
num_frames = rot_matrices.shape[0]
|
| 208 |
+
num_joints = rot_matrices.shape[1]
|
| 209 |
+
|
| 210 |
+
# Create animation stack and layer
|
| 211 |
+
animStack = fbx.FbxAnimStack.Create(fbxScene, name)
|
| 212 |
+
animLayer = fbx.FbxAnimLayer.Create(fbxScene, "Base Layer")
|
| 213 |
+
animStack.AddMember(animLayer)
|
| 214 |
+
|
| 215 |
+
# Track if root translation was applied
|
| 216 |
+
root_translation_applied = False
|
| 217 |
+
root_node = None
|
| 218 |
+
|
| 219 |
+
# Get root node's initial LclTranslation from template (this is like Translates[0] in smplh2woodfbx.py)
|
| 220 |
+
root_initial_translation = None
|
| 221 |
+
root_fbx_name = smplh_to_fbx_mapping.get("Pelvis")
|
| 222 |
+
if root_fbx_name and root_fbx_name in nodes_map:
|
| 223 |
+
root_node_temp = nodes_map[root_fbx_name]
|
| 224 |
+
initial_trans = root_node_temp.LclTranslation.Get()
|
| 225 |
+
root_initial_translation = np.array([initial_trans[0], initial_trans[1], initial_trans[2]])
|
| 226 |
+
print(f"Root initial LclTranslation from template: {root_initial_translation}")
|
| 227 |
+
|
| 228 |
+
# Animate each joint
|
| 229 |
+
for smplh_joint_name, smplh_joint_idx in SMPLH_JOINT2NUM.items():
|
| 230 |
+
if smplh_joint_idx >= num_joints:
|
| 231 |
+
continue
|
| 232 |
+
|
| 233 |
+
# Get the FBX node name from mapping
|
| 234 |
+
fbx_node_name = smplh_to_fbx_mapping.get(smplh_joint_name)
|
| 235 |
+
if not fbx_node_name:
|
| 236 |
+
if smplh_joint_idx == 0:
|
| 237 |
+
print(f"Warning: Root joint 'Pelvis' not found in mapping!")
|
| 238 |
+
continue
|
| 239 |
+
|
| 240 |
+
# Find the node
|
| 241 |
+
node = nodes_map.get(fbx_node_name)
|
| 242 |
+
if not node:
|
| 243 |
+
print(f"Warning: Joint '{smplh_joint_name}' (FBX: '{fbx_node_name}') not found in scene")
|
| 244 |
+
continue
|
| 245 |
+
|
| 246 |
+
# Animate rotation
|
| 247 |
+
_animateRotationKeyFrames(
|
| 248 |
+
animLayer=animLayer,
|
| 249 |
+
node=node,
|
| 250 |
+
rot_matrices=rot_matrices[:, smplh_joint_idx],
|
| 251 |
+
frameDuration=frameDuration,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# Animate translation for root joint (Pelvis)
|
| 255 |
+
if smplh_joint_idx == 0:
|
| 256 |
+
root_node = node
|
| 257 |
+
# Add initial offset to translations (like smplh2woodfbx.py does: Translates[0] + trans)
|
| 258 |
+
# The translations input is relative displacement, we need to add the template's initial position
|
| 259 |
+
if root_initial_translation is not None:
|
| 260 |
+
final_translations = translations + root_initial_translation
|
| 261 |
+
print(
|
| 262 |
+
f"Applying root translation to '{fbx_node_name}', frames={num_frames}, "
|
| 263 |
+
f"initial_offset={root_initial_translation}, "
|
| 264 |
+
f"final translation range: {final_translations.min(axis=0)} to {final_translations.max(axis=0)}"
|
| 265 |
+
)
|
| 266 |
+
else:
|
| 267 |
+
final_translations = translations
|
| 268 |
+
print(
|
| 269 |
+
f"Applying root translation to '{fbx_node_name}', frames={num_frames}, "
|
| 270 |
+
f"translation range: {final_translations.min(axis=0)} to {final_translations.max(axis=0)}"
|
| 271 |
+
)
|
| 272 |
+
_animateTranslationKeyFrames(
|
| 273 |
+
animLayer=animLayer,
|
| 274 |
+
node=node,
|
| 275 |
+
translations=final_translations,
|
| 276 |
+
frameDuration=frameDuration,
|
| 277 |
+
)
|
| 278 |
+
root_translation_applied = True
|
| 279 |
+
|
| 280 |
+
# If root translation was not applied, try to find root node by common names
|
| 281 |
+
if not root_translation_applied:
|
| 282 |
+
print("Warning: Root translation was not applied through normal mapping, trying fallback...")
|
| 283 |
+
root_candidates = ["Pelvis", "pelvis", "Hips", "hips", "Root", "root", "mixamorig:Hips"]
|
| 284 |
+
for candidate in root_candidates:
|
| 285 |
+
if candidate in nodes_map:
|
| 286 |
+
root_node = nodes_map[candidate]
|
| 287 |
+
# Get initial translation for fallback node
|
| 288 |
+
initial_trans = root_node.LclTranslation.Get()
|
| 289 |
+
fallback_initial = np.array([initial_trans[0], initial_trans[1], initial_trans[2]])
|
| 290 |
+
final_translations = translations + fallback_initial
|
| 291 |
+
print(
|
| 292 |
+
f"Found root node by fallback: '{candidate}', initial_offset={fallback_initial}, applying translation..."
|
| 293 |
+
)
|
| 294 |
+
_animateTranslationKeyFrames(
|
| 295 |
+
animLayer=animLayer,
|
| 296 |
+
node=root_node,
|
| 297 |
+
translations=final_translations,
|
| 298 |
+
frameDuration=frameDuration,
|
| 299 |
+
)
|
| 300 |
+
root_translation_applied = True
|
| 301 |
+
break
|
| 302 |
+
|
| 303 |
+
if not root_translation_applied:
|
| 304 |
+
print("ERROR: Could not find root node to apply translation!")
|
| 305 |
+
print(f"Available nodes: {list(nodes_map.keys())}")
|
| 306 |
+
|
| 307 |
+
return animStack
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _saveScene(filename, fbxManager, fbxScene, embed_textures=True):
|
| 311 |
+
"""Save the FBX scene to a file
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
filename: Output file path
|
| 315 |
+
fbxManager: FBX manager instance
|
| 316 |
+
fbxScene: FBX scene to save
|
| 317 |
+
embed_textures: Whether to embed textures/media in the FBX file (default True)
|
| 318 |
+
"""
|
| 319 |
+
# Configure IOSettings to embed textures/media
|
| 320 |
+
ios = fbxManager.GetIOSettings()
|
| 321 |
+
if embed_textures:
|
| 322 |
+
ios.SetBoolProp(fbx.EXP_FBX_EMBEDDED, True)
|
| 323 |
+
ios.SetBoolProp(fbx.EXP_FBX_MATERIAL, True)
|
| 324 |
+
ios.SetBoolProp(fbx.EXP_FBX_TEXTURE, True)
|
| 325 |
+
|
| 326 |
+
exporter = fbx.FbxExporter.Create(fbxManager, "")
|
| 327 |
+
isInitialized = exporter.Initialize(filename, -1, ios)
|
| 328 |
+
|
| 329 |
+
if isInitialized is False:
|
| 330 |
+
raise Exception(f"Exporter failed to initialize. Error: {exporter.GetStatus().GetErrorString()}")
|
| 331 |
+
|
| 332 |
+
exporter.Export(fbxScene)
|
| 333 |
+
exporter.Destroy()
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def _convert_smplh_to_woodfbx(
|
| 337 |
+
template_fbx_path,
|
| 338 |
+
npz_data,
|
| 339 |
+
save_fn,
|
| 340 |
+
fps=30,
|
| 341 |
+
scale=100,
|
| 342 |
+
smplh_to_fbx_mapping=None,
|
| 343 |
+
clear_animations=True,
|
| 344 |
+
):
|
| 345 |
+
"""
|
| 346 |
+
Convert SMPL-H parameters to FBX using a template FBX file.
|
| 347 |
+
The template FBX skeleton is already consistent with SMPL-H, so we directly copy parameters.
|
| 348 |
+
|
| 349 |
+
Args:
|
| 350 |
+
template_fbx_path: Path to the template FBX file (e.g., boy_Rigging_smplx.fbx)
|
| 351 |
+
npz_data: Dictionary containing SMPL-H parameters
|
| 352 |
+
- poses: (num_frames, 52, 3) or (num_frames, 156)
|
| 353 |
+
- trans: (num_frames, 3)
|
| 354 |
+
save_fn: Output FBX file path
|
| 355 |
+
fps: Frame rate
|
| 356 |
+
scale: Scale factor for translation (default 100 for m to cm conversion)
|
| 357 |
+
smplh_to_fbx_mapping: Custom mapping from SMPL-H joint names to FBX node names
|
| 358 |
+
clear_animations: Whether to clear existing animations in the template
|
| 359 |
+
|
| 360 |
+
Returns:
|
| 361 |
+
bool: True if successful
|
| 362 |
+
"""
|
| 363 |
+
# Prepare poses data
|
| 364 |
+
poses = npz_data["poses"]
|
| 365 |
+
if isinstance(poses, np.ndarray):
|
| 366 |
+
poses = torch.from_numpy(poses).float()
|
| 367 |
+
|
| 368 |
+
if len(poses.shape) == 2:
|
| 369 |
+
# (num_frames, 156) -> (num_frames, 52, 3)
|
| 370 |
+
poses = poses.reshape(poses.shape[0], -1, 3)
|
| 371 |
+
|
| 372 |
+
# Convert axis-angle to rotation matrices: (num_frames, num_joints, 3, 3)
|
| 373 |
+
rot_matrices = angle_axis_to_rotation_matrix(poses).numpy()
|
| 374 |
+
|
| 375 |
+
# Prepare translation data
|
| 376 |
+
trans = npz_data["trans"]
|
| 377 |
+
if isinstance(trans, torch.Tensor):
|
| 378 |
+
trans = trans.numpy()
|
| 379 |
+
|
| 380 |
+
# Apply scale to translation
|
| 381 |
+
translations = trans * scale
|
| 382 |
+
|
| 383 |
+
# Create FBX manager and load template
|
| 384 |
+
fbxManager = fbx.FbxManager.Create()
|
| 385 |
+
ios = fbx.FbxIOSettings.Create(fbxManager, fbx.IOSROOT)
|
| 386 |
+
fbxManager.SetIOSettings(ios)
|
| 387 |
+
|
| 388 |
+
print(f"Loading FBX template: {template_fbx_path}")
|
| 389 |
+
fbxScene = _loadFbxScene(fbxManager, template_fbx_path)
|
| 390 |
+
|
| 391 |
+
# Set time mode
|
| 392 |
+
timeMode = fbx.FbxTime().ConvertFrameRateToTimeMode(fps)
|
| 393 |
+
fbxScene.GetGlobalSettings().SetTimeMode(timeMode)
|
| 394 |
+
|
| 395 |
+
# Collect all nodes
|
| 396 |
+
rootNode = fbxScene.GetRootNode()
|
| 397 |
+
all_nodes = _collectAllNodes(rootNode)
|
| 398 |
+
skeleton_nodes = _collectSkeletonNodes(rootNode)
|
| 399 |
+
|
| 400 |
+
print(f"Found {len(all_nodes)} nodes in scene")
|
| 401 |
+
print(f"Found {len(skeleton_nodes)} skeleton nodes: {list(skeleton_nodes.keys())}")
|
| 402 |
+
|
| 403 |
+
# Use default mapping if not provided
|
| 404 |
+
if smplh_to_fbx_mapping is None:
|
| 405 |
+
smplh_to_fbx_mapping = _auto_detect_mapping(all_nodes)
|
| 406 |
+
print(f"Auto-detected {len(smplh_to_fbx_mapping)} joint mappings")
|
| 407 |
+
if "Pelvis" in smplh_to_fbx_mapping:
|
| 408 |
+
print(f" Root joint 'Pelvis' mapped to: '{smplh_to_fbx_mapping['Pelvis']}'")
|
| 409 |
+
else:
|
| 410 |
+
print(f" WARNING: Root joint 'Pelvis' not found in mapping!")
|
| 411 |
+
print(f" Available nodes: {list(all_nodes.keys())[:20]}...") # Show first 20 nodes
|
| 412 |
+
|
| 413 |
+
# Clear existing animations if requested
|
| 414 |
+
if clear_animations:
|
| 415 |
+
_clearExistingAnimations(fbxScene)
|
| 416 |
+
|
| 417 |
+
# Apply animation to skeleton
|
| 418 |
+
_applyAnimationToSkeleton(
|
| 419 |
+
fbxScene=fbxScene,
|
| 420 |
+
nodes_map=all_nodes,
|
| 421 |
+
rot_matrices=rot_matrices,
|
| 422 |
+
translations=translations,
|
| 423 |
+
fps=fps,
|
| 424 |
+
smplh_to_fbx_mapping=smplh_to_fbx_mapping,
|
| 425 |
+
name="SMPLH_Animation",
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
# Save to temporary file first, then copy to final destination
|
| 429 |
+
os.makedirs(os.path.dirname(save_fn) if os.path.dirname(save_fn) else ".", exist_ok=True)
|
| 430 |
+
with tempfile.NamedTemporaryFile(suffix=".fbx", delete=False) as tmp_f:
|
| 431 |
+
temp_file = tmp_f.name
|
| 432 |
+
|
| 433 |
+
try:
|
| 434 |
+
_saveScene(temp_file, fbxManager, fbxScene)
|
| 435 |
+
shutil.copy2(temp_file, save_fn)
|
| 436 |
+
os.remove(temp_file)
|
| 437 |
+
print(f"Successfully saved FBX to: {save_fn}")
|
| 438 |
+
except Exception as e:
|
| 439 |
+
print(f"Error saving FBX file: {e}")
|
| 440 |
+
return False
|
| 441 |
+
finally:
|
| 442 |
+
fbxManager.Destroy()
|
| 443 |
+
del fbxManager, fbxScene
|
| 444 |
+
|
| 445 |
+
return os.path.exists(save_fn)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _auto_detect_mapping(all_nodes):
|
| 449 |
+
"""Auto-detect the mapping from SMPL-H joints to FBX nodes"""
|
| 450 |
+
mapping = {}
|
| 451 |
+
for smplh_name in SMPLH_JOINT2NUM.keys():
|
| 452 |
+
# Try exact match
|
| 453 |
+
if smplh_name in all_nodes:
|
| 454 |
+
mapping[smplh_name] = smplh_name
|
| 455 |
+
# Try lowercase version
|
| 456 |
+
elif SMPLH_TO_LOWERCASE_MAPPING.get(smplh_name) in all_nodes:
|
| 457 |
+
mapping[smplh_name] = SMPLH_TO_LOWERCASE_MAPPING[smplh_name]
|
| 458 |
+
return mapping
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class SMPLH2WoodFBX:
|
| 462 |
+
"""
|
| 463 |
+
Class to convert SMPL-H parameters to FBX using a template FBX file.
|
| 464 |
+
The template FBX skeleton is already consistent with SMPL-H, so we directly copy parameters.
|
| 465 |
+
No SMPL-H model assets (model.npz) required.
|
| 466 |
+
|
| 467 |
+
Example usage:
|
| 468 |
+
converter = SMPLH2WoodFBX(
|
| 469 |
+
template_fbx_path="./assets/wooden_models/boy_Rigging_smplx.fbx"
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
# From npz file
|
| 473 |
+
converter.convert_npz_to_fbx("motion.npz", "output.fbx", fps=30)
|
| 474 |
+
|
| 475 |
+
# From parameters dict
|
| 476 |
+
params = {
|
| 477 |
+
"poses": poses_array, # (num_frames, 52, 3) or (num_frames, 156)
|
| 478 |
+
"trans": trans_array, # (num_frames, 3)
|
| 479 |
+
}
|
| 480 |
+
converter.convert_params_to_fbx(params, "output.fbx")
|
| 481 |
+
"""
|
| 482 |
+
|
| 483 |
+
def __init__(
|
| 484 |
+
self,
|
| 485 |
+
template_fbx_path: str = "./assets/wooden_models/boy_Rigging_smplx_tex.fbx",
|
| 486 |
+
smplh_to_fbx_mapping: Optional[Dict[str, str]] = None,
|
| 487 |
+
scale: float = 100,
|
| 488 |
+
):
|
| 489 |
+
"""
|
| 490 |
+
Initialize the converter.
|
| 491 |
+
|
| 492 |
+
Args:
|
| 493 |
+
template_fbx_path: Path to the template FBX file
|
| 494 |
+
smplh_to_fbx_mapping: Custom mapping from SMPL-H joint names to FBX node names
|
| 495 |
+
scale: Scale factor for translation (default 100 for m to cm conversion)
|
| 496 |
+
"""
|
| 497 |
+
print(f"[{self.__class__.__name__}] Template FBX: {template_fbx_path}")
|
| 498 |
+
self.template_fbx_path = template_fbx_path
|
| 499 |
+
self.smplh_to_fbx_mapping = smplh_to_fbx_mapping
|
| 500 |
+
self.scale = scale
|
| 501 |
+
|
| 502 |
+
# Analyze template FBX to detect joint names
|
| 503 |
+
self._analyze_template()
|
| 504 |
+
|
| 505 |
+
def _analyze_template(self):
|
| 506 |
+
"""Analyze the template FBX file to detect available skeleton nodes"""
|
| 507 |
+
fbxManager = fbx.FbxManager.Create()
|
| 508 |
+
ios = fbx.FbxIOSettings.Create(fbxManager, fbx.IOSROOT)
|
| 509 |
+
fbxManager.SetIOSettings(ios)
|
| 510 |
+
|
| 511 |
+
try:
|
| 512 |
+
fbxScene = _loadFbxScene(fbxManager, self.template_fbx_path)
|
| 513 |
+
rootNode = fbxScene.GetRootNode()
|
| 514 |
+
|
| 515 |
+
self.all_template_nodes = list(_collectAllNodes(rootNode).keys())
|
| 516 |
+
self.skeleton_template_nodes = list(_collectSkeletonNodes(rootNode).keys())
|
| 517 |
+
|
| 518 |
+
print(f"[{self.__class__.__name__}] Template nodes: {len(self.all_template_nodes)}")
|
| 519 |
+
print(f"[{self.__class__.__name__}] Skeleton nodes: {self.skeleton_template_nodes}")
|
| 520 |
+
|
| 521 |
+
# Auto-detect mapping if not provided
|
| 522 |
+
if self.smplh_to_fbx_mapping is None:
|
| 523 |
+
self.smplh_to_fbx_mapping = self._auto_detect_mapping()
|
| 524 |
+
print(f"[{self.__class__.__name__}] Auto-detected {len(self.smplh_to_fbx_mapping)} joint mappings")
|
| 525 |
+
finally:
|
| 526 |
+
fbxManager.Destroy()
|
| 527 |
+
|
| 528 |
+
def _auto_detect_mapping(self):
|
| 529 |
+
"""Auto-detect the mapping from SMPL-H joints to FBX nodes"""
|
| 530 |
+
mapping = {}
|
| 531 |
+
for smplh_name in SMPLH_JOINT2NUM.keys():
|
| 532 |
+
# Try exact match
|
| 533 |
+
if smplh_name in self.all_template_nodes:
|
| 534 |
+
mapping[smplh_name] = smplh_name
|
| 535 |
+
# Try lowercase version
|
| 536 |
+
elif SMPLH_TO_LOWERCASE_MAPPING.get(smplh_name) in self.all_template_nodes:
|
| 537 |
+
mapping[smplh_name] = SMPLH_TO_LOWERCASE_MAPPING[smplh_name]
|
| 538 |
+
return mapping
|
| 539 |
+
|
| 540 |
+
def convert_npz_to_fbx(self, npz_file, outname, fps=30, clear_animations=True):
|
| 541 |
+
"""
|
| 542 |
+
Convert an npz file containing SMPL-H parameters to FBX.
|
| 543 |
+
|
| 544 |
+
Args:
|
| 545 |
+
npz_file: Path to the npz file or dict containing SMPL-H parameters
|
| 546 |
+
outname: Output FBX file path
|
| 547 |
+
fps: Frame rate
|
| 548 |
+
clear_animations: Whether to clear existing animations in template
|
| 549 |
+
|
| 550 |
+
Returns:
|
| 551 |
+
bool: True if successful
|
| 552 |
+
"""
|
| 553 |
+
os.makedirs(os.path.dirname(outname) if os.path.dirname(outname) else ".", exist_ok=True)
|
| 554 |
+
|
| 555 |
+
if isinstance(npz_file, str) and os.path.isfile(npz_file):
|
| 556 |
+
npz_data = dict(np.load(npz_file, allow_pickle=True))
|
| 557 |
+
else:
|
| 558 |
+
npz_data = npz_file
|
| 559 |
+
|
| 560 |
+
return _convert_smplh_to_woodfbx(
|
| 561 |
+
template_fbx_path=self.template_fbx_path,
|
| 562 |
+
npz_data=npz_data,
|
| 563 |
+
save_fn=outname,
|
| 564 |
+
fps=fps,
|
| 565 |
+
scale=self.scale,
|
| 566 |
+
smplh_to_fbx_mapping=self.smplh_to_fbx_mapping,
|
| 567 |
+
clear_animations=clear_animations,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
def convert_params_to_fbx(self, params, outname, clear_animations=True):
|
| 571 |
+
"""
|
| 572 |
+
Convert SMPL-H parameters to FBX.
|
| 573 |
+
|
| 574 |
+
Args:
|
| 575 |
+
params: Dictionary containing SMPL-H parameters
|
| 576 |
+
- poses: (num_frames, 52, 3) or (num_frames, 156)
|
| 577 |
+
- trans: (num_frames, 3)
|
| 578 |
+
- mocap_framerate (optional): Frame rate
|
| 579 |
+
outname: Output FBX file path
|
| 580 |
+
clear_animations: Whether to clear existing animations in template
|
| 581 |
+
|
| 582 |
+
Returns:
|
| 583 |
+
bool: True if successful
|
| 584 |
+
"""
|
| 585 |
+
fps = params.get("mocap_framerate", 30)
|
| 586 |
+
os.makedirs(os.path.dirname(outname) if os.path.dirname(outname) else ".", exist_ok=True)
|
| 587 |
+
|
| 588 |
+
npz_data = {
|
| 589 |
+
"poses": params["poses"],
|
| 590 |
+
"trans": params["trans"],
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
return _convert_smplh_to_woodfbx(
|
| 594 |
+
template_fbx_path=self.template_fbx_path,
|
| 595 |
+
npz_data=npz_data,
|
| 596 |
+
save_fn=outname,
|
| 597 |
+
fps=fps,
|
| 598 |
+
scale=self.scale,
|
| 599 |
+
smplh_to_fbx_mapping=self.smplh_to_fbx_mapping,
|
| 600 |
+
clear_animations=clear_animations,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
if __name__ == "__main__":
|
| 605 |
+
# python hymotion/utils/smplh2woodfbx.py
|
| 606 |
+
import argparse
|
| 607 |
+
|
| 608 |
+
parser = argparse.ArgumentParser()
|
| 609 |
+
parser.add_argument("root", type=str)
|
| 610 |
+
args = parser.parse_args()
|
| 611 |
+
|
| 612 |
+
converter = SMPLH2WoodFBX(
|
| 613 |
+
template_fbx_path="./assets/wooden_models/boy_Rigging_smplx_tex.fbx",
|
| 614 |
+
scale=100,
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
if os.path.isdir(args.root):
|
| 618 |
+
npzfiles = sorted(glob.glob(os.path.join(args.root, "*.npz")))
|
| 619 |
+
else:
|
| 620 |
+
if args.root.endswith(".npz"):
|
| 621 |
+
npzfiles = [args.root]
|
| 622 |
+
else:
|
| 623 |
+
raise ValueError(f"Unknown file type: {args.root}")
|
| 624 |
+
|
| 625 |
+
for npzfile in npzfiles:
|
| 626 |
+
converter.convert_npz_to_fbx(npzfile, npzfile.replace(".npz", ".fbx").replace("motions", "motions_fbx"))
|
hymotion/utils/t2m_runtime.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# t2m_runtime.py
|
| 2 |
+
import os
|
| 3 |
+
import threading
|
| 4 |
+
import time
|
| 5 |
+
import uuid
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
from ..prompt_engineering.prompt_rewrite import PromptRewriter
|
| 12 |
+
from .loaders import load_object
|
| 13 |
+
from .visualize_mesh_web import save_visualization_data, generate_static_html_content
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import fbx
|
| 17 |
+
|
| 18 |
+
FBX_AVAILABLE = True
|
| 19 |
+
print(">>> FBX module found.")
|
| 20 |
+
except ImportError:
|
| 21 |
+
FBX_AVAILABLE = False
|
| 22 |
+
print(">>> FBX module not found.")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _get_local_ip():
|
| 26 |
+
import subprocess
|
| 27 |
+
|
| 28 |
+
result = subprocess.run(["hostname", "-I"], capture_output=True, text=True, timeout=5)
|
| 29 |
+
if result.returncode == 0:
|
| 30 |
+
for ip in result.stdout.strip().split():
|
| 31 |
+
if not ip.startswith("127.") and not ip.startswith("172.17."):
|
| 32 |
+
return ip
|
| 33 |
+
return "localhost"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _now():
|
| 37 |
+
t = time.time()
|
| 38 |
+
ms = int((t - int(t)) * 1000)
|
| 39 |
+
return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class T2MRuntime:
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
config_path: str,
|
| 46 |
+
ckpt_name: str = "latest.ckpt",
|
| 47 |
+
skip_text: bool = False,
|
| 48 |
+
device_ids: Union[list[int], None] = None,
|
| 49 |
+
prompt_engineering_host: Optional[str] = None,
|
| 50 |
+
skip_model_loading: bool = False,
|
| 51 |
+
force_cpu: bool = False,
|
| 52 |
+
):
|
| 53 |
+
self.config_path = config_path
|
| 54 |
+
self.ckpt_name = ckpt_name
|
| 55 |
+
self.skip_text = skip_text
|
| 56 |
+
self.prompt_engineering_host = prompt_engineering_host
|
| 57 |
+
self.skip_model_loading = skip_model_loading
|
| 58 |
+
self.local_ip = _get_local_ip()
|
| 59 |
+
|
| 60 |
+
if force_cpu:
|
| 61 |
+
print(">>> [INFO] CPU mode enabled via HY_MOTION_DEVICE=cpu environment variable")
|
| 62 |
+
self.device_ids = []
|
| 63 |
+
elif torch.cuda.is_available():
|
| 64 |
+
all_ids = list(range(torch.cuda.device_count()))
|
| 65 |
+
self.device_ids = all_ids if device_ids is None else [i for i in device_ids if i in all_ids]
|
| 66 |
+
else:
|
| 67 |
+
self.device_ids = []
|
| 68 |
+
|
| 69 |
+
self.pipelines = []
|
| 70 |
+
self._gpu_load = []
|
| 71 |
+
self._lock = threading.Lock()
|
| 72 |
+
self._loaded = False
|
| 73 |
+
|
| 74 |
+
self.prompt_rewriter = PromptRewriter(host=self.prompt_engineering_host)
|
| 75 |
+
# Skip model loading if checkpoint not found
|
| 76 |
+
if self.skip_model_loading:
|
| 77 |
+
print(">>> [WARNING] Checkpoint not found, will use randomly initialized model weights")
|
| 78 |
+
self.load()
|
| 79 |
+
self.fbx_available = FBX_AVAILABLE
|
| 80 |
+
if self.fbx_available:
|
| 81 |
+
try:
|
| 82 |
+
from .smplh2woodfbx import SMPLH2WoodFBX
|
| 83 |
+
|
| 84 |
+
self.fbx_converter = SMPLH2WoodFBX()
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f">>> Failed to initialize FBX converter: {e}")
|
| 87 |
+
self.fbx_available = False
|
| 88 |
+
self.fbx_converter = None
|
| 89 |
+
else:
|
| 90 |
+
self.fbx_converter = None
|
| 91 |
+
print(">>> FBX module not found. FBX export will be disabled.")
|
| 92 |
+
|
| 93 |
+
device_info = self.device_ids if self.device_ids else "cpu"
|
| 94 |
+
if self.skip_model_loading:
|
| 95 |
+
print(f">>> T2MRuntime initialized (using randomly initialized weights) in IP {self.local_ip}, devices={device_info}")
|
| 96 |
+
else:
|
| 97 |
+
print(f">>> T2MRuntime loaded in IP {self.local_ip}, devices={device_info}")
|
| 98 |
+
|
| 99 |
+
def load(self):
|
| 100 |
+
if self._loaded:
|
| 101 |
+
return
|
| 102 |
+
print(f">>> Loading model from {self.config_path}...")
|
| 103 |
+
|
| 104 |
+
with open(self.config_path, "r") as f:
|
| 105 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
| 106 |
+
|
| 107 |
+
# Use allow_empty_ckpt=True when skip_model_loading is True
|
| 108 |
+
allow_empty_ckpt = self.skip_model_loading
|
| 109 |
+
|
| 110 |
+
if not self.device_ids:
|
| 111 |
+
pipeline = load_object(
|
| 112 |
+
config["train_pipeline"],
|
| 113 |
+
config["train_pipeline_args"],
|
| 114 |
+
network_module=config["network_module"],
|
| 115 |
+
network_module_args=config["network_module_args"],
|
| 116 |
+
)
|
| 117 |
+
device = torch.device("cpu")
|
| 118 |
+
pipeline.load_in_demo(
|
| 119 |
+
self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text, allow_empty_ckpt=allow_empty_ckpt
|
| 120 |
+
)
|
| 121 |
+
pipeline.to(device)
|
| 122 |
+
self.pipelines = [pipeline]
|
| 123 |
+
self._gpu_load = [0]
|
| 124 |
+
else:
|
| 125 |
+
for gid in self.device_ids:
|
| 126 |
+
p = load_object(
|
| 127 |
+
config["train_pipeline"],
|
| 128 |
+
config["train_pipeline_args"],
|
| 129 |
+
network_module=config["network_module"],
|
| 130 |
+
network_module_args=config["network_module_args"],
|
| 131 |
+
)
|
| 132 |
+
p.load_in_demo(self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text, allow_empty_ckpt=allow_empty_ckpt)
|
| 133 |
+
p.to(torch.device(f"cuda:{gid}"))
|
| 134 |
+
self.pipelines.append(p)
|
| 135 |
+
self._gpu_load = [0] * len(self.pipelines)
|
| 136 |
+
|
| 137 |
+
self._loaded = True
|
| 138 |
+
|
| 139 |
+
def _acquire_pipeline(self) -> int:
|
| 140 |
+
while True:
|
| 141 |
+
with self._lock:
|
| 142 |
+
for i in range(len(self._gpu_load)):
|
| 143 |
+
if self._gpu_load[i] == 0:
|
| 144 |
+
self._gpu_load[i] = 1
|
| 145 |
+
return i
|
| 146 |
+
time.sleep(0.01)
|
| 147 |
+
|
| 148 |
+
def _release_pipeline(self, idx: int):
|
| 149 |
+
with self._lock:
|
| 150 |
+
self._gpu_load[idx] = 0
|
| 151 |
+
|
| 152 |
+
def test_dit_inference(self, duration: float = 2.0, seed: int = 42) -> bool:
|
| 153 |
+
"""
|
| 154 |
+
Test DiT model inference with unconditional/blank input.
|
| 155 |
+
This method is used to verify the DiT model works before loading text encoder.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
duration: Duration of the test motion in seconds
|
| 159 |
+
seed: Random seed for reproducibility
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
True if inference succeeds and produces valid output
|
| 163 |
+
"""
|
| 164 |
+
if not self.pipelines:
|
| 165 |
+
raise RuntimeError("No pipeline loaded. Call load() first.")
|
| 166 |
+
|
| 167 |
+
pi = self._acquire_pipeline()
|
| 168 |
+
try:
|
| 169 |
+
pipeline = self.pipelines[pi]
|
| 170 |
+
pipeline.eval()
|
| 171 |
+
device = next(pipeline.parameters()).device
|
| 172 |
+
|
| 173 |
+
# Calculate frame length from duration (assuming 30fps output, 20fps internal)
|
| 174 |
+
length = int(duration * 20)
|
| 175 |
+
length = min(length, pipeline.train_frames)
|
| 176 |
+
|
| 177 |
+
# Use null features for unconditional generation
|
| 178 |
+
batch_size = 1
|
| 179 |
+
vtxt_input = pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device)
|
| 180 |
+
ctxt_input = pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device)
|
| 181 |
+
ctxt_length = torch.tensor([1] * batch_size, device=device)
|
| 182 |
+
|
| 183 |
+
# Create masks
|
| 184 |
+
from ..pipeline.motion_diffusion import length_to_mask
|
| 185 |
+
|
| 186 |
+
ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
|
| 187 |
+
x_length = torch.LongTensor([length] * batch_size).to(device)
|
| 188 |
+
x_mask_temporal = length_to_mask(x_length, pipeline.train_frames)
|
| 189 |
+
|
| 190 |
+
# Run denoising inference
|
| 191 |
+
print(f"\t>>> Running DiT inference test: length={length}, device={device}")
|
| 192 |
+
|
| 193 |
+
# Create random noise
|
| 194 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 195 |
+
latent_shape = (batch_size, pipeline.train_frames, pipeline.mean.shape[-1])
|
| 196 |
+
latents = torch.randn(latent_shape, generator=generator, device=device, dtype=vtxt_input.dtype)
|
| 197 |
+
|
| 198 |
+
# Simple single-step denoising test (just forward pass)
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
# Get timestep
|
| 201 |
+
timesteps = torch.tensor([0.5], device=device, dtype=vtxt_input.dtype).expand(batch_size)
|
| 202 |
+
|
| 203 |
+
# Forward pass through DiT
|
| 204 |
+
# Use correct parameter names for HunyuanMotionMMDiT.forward()
|
| 205 |
+
_ = pipeline.motion_transformer(
|
| 206 |
+
x=latents,
|
| 207 |
+
ctxt_input=ctxt_input,
|
| 208 |
+
vtxt_input=vtxt_input,
|
| 209 |
+
timesteps=timesteps,
|
| 210 |
+
x_mask_temporal=x_mask_temporal,
|
| 211 |
+
ctxt_mask_temporal=ctxt_mask_temporal,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
print(f"\t>>> DiT forward pass completed successfully!")
|
| 215 |
+
return True
|
| 216 |
+
|
| 217 |
+
except Exception as e:
|
| 218 |
+
print(f"\t>>> DiT inference test failed: {e}")
|
| 219 |
+
raise
|
| 220 |
+
finally:
|
| 221 |
+
self._release_pipeline(pi)
|
| 222 |
+
|
| 223 |
+
def load_text_encoder(self) -> None:
|
| 224 |
+
"""
|
| 225 |
+
Load text encoder for all pipelines.
|
| 226 |
+
This is called after DiT model testing to complete the initialization.
|
| 227 |
+
"""
|
| 228 |
+
if not self.pipelines:
|
| 229 |
+
raise RuntimeError("No pipeline loaded. Call load() first.")
|
| 230 |
+
|
| 231 |
+
print(">>> Loading text encoder for all pipelines...")
|
| 232 |
+
for i, pipeline in enumerate(self.pipelines):
|
| 233 |
+
if not hasattr(pipeline, "text_encoder") or pipeline.text_encoder is None:
|
| 234 |
+
device = next(pipeline.parameters()).device
|
| 235 |
+
pipeline.text_encoder = load_object(pipeline._text_encoder_module, pipeline._text_encoder_cfg)
|
| 236 |
+
pipeline.text_encoder.to(device)
|
| 237 |
+
print(f"\t>>> Text encoder loaded for pipeline {i} on {device}")
|
| 238 |
+
|
| 239 |
+
# Update skip_text flag
|
| 240 |
+
self.skip_text = False
|
| 241 |
+
print(">>> Text encoder loading completed!")
|
| 242 |
+
|
| 243 |
+
def rewrite_text_and_infer_time(self, text: str) -> Tuple[float, str]:
|
| 244 |
+
print("Start rewriting text...")
|
| 245 |
+
duration, rewritten_text = self.prompt_rewriter.rewrite_prompt_and_infer_time(f"{text}")
|
| 246 |
+
print(f"\t>>> Rewritten text: {rewritten_text}, duration: {duration:.2f} seconds")
|
| 247 |
+
return duration, rewritten_text
|
| 248 |
+
|
| 249 |
+
def generate_motion(
|
| 250 |
+
self,
|
| 251 |
+
text: str,
|
| 252 |
+
seeds_csv: str,
|
| 253 |
+
duration: float,
|
| 254 |
+
cfg_scale: float,
|
| 255 |
+
output_format: str = "fbx",
|
| 256 |
+
output_dir: Optional[str] = None,
|
| 257 |
+
output_filename: Optional[str] = None,
|
| 258 |
+
original_text: Optional[str] = None,
|
| 259 |
+
use_special_game_feat: bool = False,
|
| 260 |
+
) -> Tuple[Union[str, list[str]], dict]:
|
| 261 |
+
self.load()
|
| 262 |
+
seeds = [int(s.strip()) for s in seeds_csv.split(",") if s.strip() != ""]
|
| 263 |
+
pi = self._acquire_pipeline()
|
| 264 |
+
try:
|
| 265 |
+
pipeline = self.pipelines[pi]
|
| 266 |
+
pipeline.eval()
|
| 267 |
+
|
| 268 |
+
# When skip_text=True (debug mode), use blank text features
|
| 269 |
+
if self.skip_text:
|
| 270 |
+
print(">>> [Debug Mode] Using blank text features (skip_text=True)")
|
| 271 |
+
device = next(pipeline.parameters()).device
|
| 272 |
+
batch_size = len(seeds) if seeds else 1
|
| 273 |
+
# Create blank hidden_state_dict using null features
|
| 274 |
+
hidden_state_dict = {
|
| 275 |
+
"text_vec_raw": pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device),
|
| 276 |
+
"text_ctxt_raw": pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device),
|
| 277 |
+
"text_ctxt_raw_length": torch.tensor([1] * batch_size, device=device),
|
| 278 |
+
}
|
| 279 |
+
# Disable CFG in debug mode (use cfg_scale=1.0)
|
| 280 |
+
model_output = pipeline.generate(
|
| 281 |
+
text,
|
| 282 |
+
seeds,
|
| 283 |
+
duration,
|
| 284 |
+
cfg_scale=1.0,
|
| 285 |
+
use_special_game_feat=False,
|
| 286 |
+
hidden_state_dict=hidden_state_dict,
|
| 287 |
+
)
|
| 288 |
+
else:
|
| 289 |
+
model_output = pipeline.generate(
|
| 290 |
+
text, seeds, duration, cfg_scale=cfg_scale, use_special_game_feat=use_special_game_feat
|
| 291 |
+
)
|
| 292 |
+
finally:
|
| 293 |
+
self._release_pipeline(pi)
|
| 294 |
+
|
| 295 |
+
ts = _now()
|
| 296 |
+
save_data, base_filename = save_visualization_data(
|
| 297 |
+
output=model_output,
|
| 298 |
+
text=text if original_text is None else original_text,
|
| 299 |
+
rewritten_text=text,
|
| 300 |
+
timestamp=ts,
|
| 301 |
+
output_dir=output_dir,
|
| 302 |
+
output_filename=output_filename,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
html_content = self._generate_html_content(
|
| 306 |
+
timestamp=ts,
|
| 307 |
+
file_path=base_filename,
|
| 308 |
+
output_dir=output_dir,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if output_format == "fbx" and not self.fbx_available:
|
| 312 |
+
print(">>> Warning: FBX export requested but FBX SDK is not available. Falling back to dict format.")
|
| 313 |
+
output_format = "dict"
|
| 314 |
+
|
| 315 |
+
if output_format == "fbx" and self.fbx_available:
|
| 316 |
+
fbx_files = self._generate_fbx_files(
|
| 317 |
+
visualization_data=save_data,
|
| 318 |
+
output_dir=output_dir,
|
| 319 |
+
fbx_filename=output_filename,
|
| 320 |
+
)
|
| 321 |
+
return html_content, fbx_files, model_output
|
| 322 |
+
elif output_format == "dict":
|
| 323 |
+
# Return HTML content and empty list for fbx_files when using dict format
|
| 324 |
+
return html_content, [], model_output
|
| 325 |
+
else:
|
| 326 |
+
raise ValueError(f">>> Invalid output format: {output_format}")
|
| 327 |
+
|
| 328 |
+
def _generate_html_content(
|
| 329 |
+
self,
|
| 330 |
+
timestamp: str,
|
| 331 |
+
file_path: str,
|
| 332 |
+
output_dir: Optional[str] = None,
|
| 333 |
+
) -> str:
|
| 334 |
+
"""
|
| 335 |
+
Generate static HTML content with embedded data for iframe srcdoc.
|
| 336 |
+
All JavaScript code is embedded directly in the HTML, no external static resources needed.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
timestamp: Timestamp string for logging
|
| 340 |
+
file_path: Base filename (without extension)
|
| 341 |
+
output_dir: Directory where NPZ/meta files are stored
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
HTML content string (to be used in iframe srcdoc)
|
| 345 |
+
"""
|
| 346 |
+
print(f">>> Generating static HTML content, timestamp: {timestamp}")
|
| 347 |
+
gradio_dir = output_dir if output_dir is not None else "output/gradio"
|
| 348 |
+
|
| 349 |
+
try:
|
| 350 |
+
# Generate static HTML content with embedded data (all JS is embedded in template)
|
| 351 |
+
html_content = generate_static_html_content(
|
| 352 |
+
folder_name=gradio_dir,
|
| 353 |
+
file_name=file_path,
|
| 354 |
+
hide_captions=False,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
print(f">>> Static HTML content generated for: {file_path}")
|
| 358 |
+
return html_content
|
| 359 |
+
|
| 360 |
+
except Exception as e:
|
| 361 |
+
print(f">>> Failed to generate static HTML content: {e}")
|
| 362 |
+
import traceback
|
| 363 |
+
traceback.print_exc()
|
| 364 |
+
# Return error HTML
|
| 365 |
+
return f"<html><body><h1>Error generating visualization</h1><p>{str(e)}</p></body></html>"
|
| 366 |
+
|
| 367 |
+
def _generate_fbx_files(
|
| 368 |
+
self,
|
| 369 |
+
visualization_data: dict,
|
| 370 |
+
output_dir: Optional[str] = None,
|
| 371 |
+
fbx_filename: Optional[str] = None,
|
| 372 |
+
) -> List[str]:
|
| 373 |
+
assert "smpl_data" in visualization_data, "smpl_data not found in visualization_data"
|
| 374 |
+
fbx_files = []
|
| 375 |
+
if output_dir is None:
|
| 376 |
+
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
| 377 |
+
output_dir = os.path.join(root_dir, "output", "gradio")
|
| 378 |
+
|
| 379 |
+
smpl_data_list = visualization_data["smpl_data"]
|
| 380 |
+
|
| 381 |
+
unique_id = str(uuid.uuid4())[:8]
|
| 382 |
+
text = visualization_data["text"]
|
| 383 |
+
timestamp = visualization_data["timestamp"]
|
| 384 |
+
for bb in range(len(smpl_data_list)):
|
| 385 |
+
smpl_data = smpl_data_list[bb]
|
| 386 |
+
if fbx_filename is None:
|
| 387 |
+
fbx_filename_bb = f"{timestamp}_{unique_id}_{bb:03d}.fbx"
|
| 388 |
+
else:
|
| 389 |
+
fbx_filename_bb = f"{fbx_filename}_{bb:03d}.fbx"
|
| 390 |
+
fbx_path = os.path.join(output_dir, fbx_filename_bb)
|
| 391 |
+
success = self.fbx_converter.convert_npz_to_fbx(smpl_data, fbx_path)
|
| 392 |
+
if success:
|
| 393 |
+
fbx_files.append(fbx_path)
|
| 394 |
+
print(f"\t>>> FBX file generated: {fbx_path}")
|
| 395 |
+
txt_path = fbx_path.replace(".fbx", ".txt")
|
| 396 |
+
with open(txt_path, "w", encoding="utf-8") as f:
|
| 397 |
+
f.write(text)
|
| 398 |
+
fbx_files.append(txt_path)
|
| 399 |
+
|
| 400 |
+
return fbx_files
|
hymotion/utils/type_converter.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_module_device(module: nn.Module) -> torch.device:
|
| 6 |
+
"""Get the device of a module.
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
module (nn.Module): A module contains the parameters.
|
| 10 |
+
|
| 11 |
+
Returns:
|
| 12 |
+
torch.device: The device of the module.
|
| 13 |
+
"""
|
| 14 |
+
try:
|
| 15 |
+
next(module.parameters())
|
| 16 |
+
except StopIteration:
|
| 17 |
+
raise ValueError("The input module should contain parameters.")
|
| 18 |
+
|
| 19 |
+
if next(module.parameters()).is_cuda:
|
| 20 |
+
return torch.device(next(module.parameters()).get_device())
|
| 21 |
+
|
| 22 |
+
return torch.device("cpu")
|
hymotion/utils/visualize_mesh_web.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import threading
|
| 5 |
+
from typing import Any, Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
|
| 11 |
+
_FILE_ACCESS_LOCK = threading.Lock()
|
| 12 |
+
|
| 13 |
+
# Template directory path
|
| 14 |
+
_TEMPLATE_DIR = os.path.join(
|
| 15 |
+
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
| 16 |
+
"scripts", "gradio", "templates"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def sanitize_filename(filename: str) -> str:
|
| 21 |
+
"""
|
| 22 |
+
Sanitize filename to prevent path traversal attacks
|
| 23 |
+
Args:
|
| 24 |
+
filename: original filename
|
| 25 |
+
Returns:
|
| 26 |
+
sanitized filename
|
| 27 |
+
"""
|
| 28 |
+
if not filename:
|
| 29 |
+
return ""
|
| 30 |
+
|
| 31 |
+
# remove all path traversal characters
|
| 32 |
+
filename = re.sub(r"\.\.(/|\\\\\\)?", "", filename)
|
| 33 |
+
filename = filename.strip("./\\")
|
| 34 |
+
|
| 35 |
+
# only allow letters, numbers, underscores, hyphens and dots
|
| 36 |
+
# dots are only allowed once in the extension
|
| 37 |
+
filename = re.sub(r"[^a-zA-Z0-9_.-]", "", filename)
|
| 38 |
+
|
| 39 |
+
# prevent multiple consecutive dots
|
| 40 |
+
while ".." in filename:
|
| 41 |
+
filename = filename.replace("..", ".")
|
| 42 |
+
|
| 43 |
+
# prevent starting with a dot (hidden file)
|
| 44 |
+
if filename.startswith("."):
|
| 45 |
+
filename = filename[1:]
|
| 46 |
+
|
| 47 |
+
# limit file name length
|
| 48 |
+
if len(filename) > 255:
|
| 49 |
+
filename = filename[:255]
|
| 50 |
+
|
| 51 |
+
return filename
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def sanitize_folder_name(folder_name: str) -> str:
|
| 55 |
+
"""
|
| 56 |
+
Sanitize folder name to prevent path traversal attacks
|
| 57 |
+
Args:
|
| 58 |
+
folder_name: original folder name
|
| 59 |
+
Returns:
|
| 60 |
+
sanitized folder name
|
| 61 |
+
"""
|
| 62 |
+
if not folder_name:
|
| 63 |
+
return "output" # default folder
|
| 64 |
+
|
| 65 |
+
# remove all path traversal characters
|
| 66 |
+
folder_name = re.sub(r"\.\.(/|\\\\\\)?", "", folder_name)
|
| 67 |
+
folder_name = folder_name.strip("./\\")
|
| 68 |
+
|
| 69 |
+
# only allow letters, numbers, underscores, hyphens and slashes (for subdirectories)
|
| 70 |
+
# but need to ensure slashes don't cause path traversal
|
| 71 |
+
folder_name = re.sub(r"[^a-zA-Z0-9_./-]", "", folder_name)
|
| 72 |
+
|
| 73 |
+
# split path and clean each part
|
| 74 |
+
parts = folder_name.split("/")
|
| 75 |
+
cleaned_parts = []
|
| 76 |
+
for part in parts:
|
| 77 |
+
if part and part not in [".", ".."]:
|
| 78 |
+
# clean each part
|
| 79 |
+
part = re.sub(r"[^a-zA-Z0-9_-]", "", part)
|
| 80 |
+
if part:
|
| 81 |
+
cleaned_parts.append(part)
|
| 82 |
+
|
| 83 |
+
# recombine, allow at most 3 levels of directory depth
|
| 84 |
+
if len(cleaned_parts) > 3:
|
| 85 |
+
cleaned_parts = cleaned_parts[:3]
|
| 86 |
+
|
| 87 |
+
return "/".join(cleaned_parts) if cleaned_parts else "output"
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def safe_path_join(base_dir: str, *paths: str) -> str:
|
| 91 |
+
"""
|
| 92 |
+
Safe path joining, ensure the resulting path is within base_dir
|
| 93 |
+
Args:
|
| 94 |
+
base_dir: base directory
|
| 95 |
+
*paths: paths to join
|
| 96 |
+
Returns:
|
| 97 |
+
joined path
|
| 98 |
+
Raises:
|
| 99 |
+
ValueError: if path traversal is detected
|
| 100 |
+
"""
|
| 101 |
+
# clean all paths
|
| 102 |
+
cleaned_paths = []
|
| 103 |
+
for path in paths:
|
| 104 |
+
if path:
|
| 105 |
+
# clean each path part
|
| 106 |
+
path = re.sub(r"\.\.(/|\\\\\\)?", "", path)
|
| 107 |
+
path = path.strip("./\\")
|
| 108 |
+
path = re.sub(r"[^a-zA-Z0-9_.-]", "", path)
|
| 109 |
+
if path:
|
| 110 |
+
cleaned_paths.append(path)
|
| 111 |
+
|
| 112 |
+
# join paths
|
| 113 |
+
full_path = os.path.join(base_dir, *cleaned_paths)
|
| 114 |
+
|
| 115 |
+
# ensure the resulting path is within base_dir
|
| 116 |
+
base_dir = os.path.realpath(base_dir)
|
| 117 |
+
full_path = os.path.realpath(os.path.normpath(full_path))
|
| 118 |
+
|
| 119 |
+
if os.path.commonpath([base_dir, full_path]) != base_dir:
|
| 120 |
+
raise ValueError(f"Path traversal detected: {full_path} is outside {base_dir}")
|
| 121 |
+
|
| 122 |
+
return full_path
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _get_root_dir() -> str:
|
| 126 |
+
return os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_output_dir(sub_path: str = "") -> str:
|
| 130 |
+
output_base = _get_root_dir()
|
| 131 |
+
if not os.path.exists(output_base):
|
| 132 |
+
os.makedirs(output_base, exist_ok=True)
|
| 133 |
+
if sub_path:
|
| 134 |
+
parts = [p for p in sub_path.replace("\\", "/").split("/") if p]
|
| 135 |
+
else:
|
| 136 |
+
parts = []
|
| 137 |
+
return safe_path_join(output_base, *parts)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def save_visualization_data(
|
| 141 |
+
output: Dict[str, Union[Tensor, list[str]]],
|
| 142 |
+
text: str,
|
| 143 |
+
rewritten_text: Union[str, list[str]],
|
| 144 |
+
timestamp: str,
|
| 145 |
+
output_dir: Optional[str] = None,
|
| 146 |
+
output_filename: Optional[str] = None,
|
| 147 |
+
):
|
| 148 |
+
from ..pipeline.body_model import construct_smpl_data_dict
|
| 149 |
+
|
| 150 |
+
if output_dir is None:
|
| 151 |
+
output_dir = get_output_dir(sub_path="output/gradio")
|
| 152 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 153 |
+
|
| 154 |
+
# for metadata
|
| 155 |
+
base_filename = output_filename if output_filename else timestamp
|
| 156 |
+
meta_path = safe_path_join(output_dir, f"{base_filename}_meta.json")
|
| 157 |
+
if isinstance(rewritten_text, str):
|
| 158 |
+
rewritten_text = [rewritten_text]
|
| 159 |
+
batch_size = output["rot6d"].shape[0]
|
| 160 |
+
meta_data = {
|
| 161 |
+
"timestamp": timestamp,
|
| 162 |
+
"text": text,
|
| 163 |
+
"text_rewrite": rewritten_text,
|
| 164 |
+
"num_samples": batch_size,
|
| 165 |
+
"base_filename": base_filename,
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
with _FILE_ACCESS_LOCK:
|
| 169 |
+
with open(meta_path, "w") as f:
|
| 170 |
+
json.dump(meta_data, f, indent=2)
|
| 171 |
+
|
| 172 |
+
# for smpl data
|
| 173 |
+
rot6d = output["rot6d"]
|
| 174 |
+
transl = output["transl"]
|
| 175 |
+
|
| 176 |
+
all_smpl_data = [] # for FBX generator
|
| 177 |
+
|
| 178 |
+
for bb in range(batch_size):
|
| 179 |
+
# build data
|
| 180 |
+
smpl_data = construct_smpl_data_dict(rot6d[bb].clone(), transl[bb].clone())
|
| 181 |
+
all_smpl_data.append(smpl_data)
|
| 182 |
+
|
| 183 |
+
# prepare dictionary to save into NPZ
|
| 184 |
+
npz_dict = {}
|
| 185 |
+
npz_dict["gender"] = np.array([smpl_data.get("gender", "neutral")], dtype=str)
|
| 186 |
+
|
| 187 |
+
for key in ["Rh", "trans", "poses", "betas"]:
|
| 188 |
+
if key in smpl_data:
|
| 189 |
+
val = smpl_data[key]
|
| 190 |
+
if isinstance(val, (list, tuple)):
|
| 191 |
+
val = np.array(val)
|
| 192 |
+
elif isinstance(val, torch.Tensor):
|
| 193 |
+
val = val.cpu().numpy()
|
| 194 |
+
npz_dict[key] = val
|
| 195 |
+
|
| 196 |
+
# save single NPZ
|
| 197 |
+
sample_filename = f"{base_filename}_{bb:03d}.npz"
|
| 198 |
+
sample_path = safe_path_join(output_dir, sample_filename)
|
| 199 |
+
|
| 200 |
+
with _FILE_ACCESS_LOCK:
|
| 201 |
+
np.savez_compressed(sample_path, **npz_dict)
|
| 202 |
+
|
| 203 |
+
# construct memory dictionary to return (for compatibility)
|
| 204 |
+
memory_data = {
|
| 205 |
+
"timestamp": timestamp,
|
| 206 |
+
"text": text,
|
| 207 |
+
"text_rewrite": rewritten_text,
|
| 208 |
+
"smpl_data": all_smpl_data,
|
| 209 |
+
"meta_data": [],
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
# return base filename, subsequent logic will use this as a basis for finding _meta.json or _000.npz
|
| 213 |
+
return memory_data, base_filename
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def get_cached_captions(folder_name: str, file_name: str) -> List[dict]:
|
| 217 |
+
"""read _meta.json to get text"""
|
| 218 |
+
|
| 219 |
+
folder_name = sanitize_folder_name(folder_name)
|
| 220 |
+
file_name = sanitize_filename(file_name)
|
| 221 |
+
|
| 222 |
+
base_dir = get_output_dir(folder_name)
|
| 223 |
+
# try to add suffix or find
|
| 224 |
+
meta_path = safe_path_join(base_dir, f"{file_name}_meta.json")
|
| 225 |
+
|
| 226 |
+
if not os.path.exists(meta_path):
|
| 227 |
+
if "_" in file_name:
|
| 228 |
+
prefix = file_name.rsplit("_", 1)[0]
|
| 229 |
+
prefix = sanitize_filename(prefix)
|
| 230 |
+
meta_path_alt = safe_path_join(base_dir, f"{prefix}_meta.json")
|
| 231 |
+
if os.path.exists(meta_path_alt):
|
| 232 |
+
meta_path = meta_path_alt
|
| 233 |
+
else:
|
| 234 |
+
return []
|
| 235 |
+
else:
|
| 236 |
+
return []
|
| 237 |
+
|
| 238 |
+
try:
|
| 239 |
+
with _FILE_ACCESS_LOCK:
|
| 240 |
+
with open(meta_path, "r") as f:
|
| 241 |
+
data = json.load(f)
|
| 242 |
+
|
| 243 |
+
text = data.get("text", "")
|
| 244 |
+
text_rewrite = data.get("text_rewrite", [])
|
| 245 |
+
|
| 246 |
+
captions = []
|
| 247 |
+
for i, t in enumerate(text_rewrite):
|
| 248 |
+
item = {"short caption+": f"{t}", "start_time": None, "end_time": None}
|
| 249 |
+
if text and text != t:
|
| 250 |
+
item["short caption"] = text
|
| 251 |
+
captions.append(item)
|
| 252 |
+
return captions
|
| 253 |
+
except Exception as e:
|
| 254 |
+
print(f"Error reading meta json: {e}")
|
| 255 |
+
return []
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def get_cached_smpl_frames(folder_name: str, file_name: str) -> List[list]:
|
| 259 |
+
"""
|
| 260 |
+
read logic needs to be adjusted:
|
| 261 |
+
1. if file_name is the base name, load all samples
|
| 262 |
+
2. if file_name is a specific sample name, only load that sample
|
| 263 |
+
"""
|
| 264 |
+
folder_name = sanitize_folder_name(folder_name)
|
| 265 |
+
file_name = sanitize_filename(file_name)
|
| 266 |
+
|
| 267 |
+
base_dir = get_output_dir(folder_name)
|
| 268 |
+
|
| 269 |
+
npz_direct_path = safe_path_join(base_dir, f"{file_name}.npz")
|
| 270 |
+
meta_path = safe_path_join(base_dir, f"{file_name}_meta.json")
|
| 271 |
+
|
| 272 |
+
target_indices = []
|
| 273 |
+
base_name = file_name
|
| 274 |
+
|
| 275 |
+
if os.path.isfile(npz_direct_path):
|
| 276 |
+
try:
|
| 277 |
+
if "_" in file_name:
|
| 278 |
+
prefix, suffix = file_name.rsplit("_", 1)
|
| 279 |
+
if suffix.isdigit():
|
| 280 |
+
num_samples = 1
|
| 281 |
+
base_name = prefix
|
| 282 |
+
target_indices = [int(suffix)]
|
| 283 |
+
else:
|
| 284 |
+
pass
|
| 285 |
+
else:
|
| 286 |
+
pass
|
| 287 |
+
except ValueError:
|
| 288 |
+
pass
|
| 289 |
+
if not target_indices:
|
| 290 |
+
return []
|
| 291 |
+
elif os.path.exists(meta_path):
|
| 292 |
+
try:
|
| 293 |
+
with open(meta_path, "r") as f:
|
| 294 |
+
meta = json.load(f)
|
| 295 |
+
num_samples = meta.get("num_samples", 0)
|
| 296 |
+
target_indices = range(num_samples)
|
| 297 |
+
except Exception as e:
|
| 298 |
+
print(f"Error reading meta: {e}")
|
| 299 |
+
return []
|
| 300 |
+
else:
|
| 301 |
+
return []
|
| 302 |
+
|
| 303 |
+
all_people = []
|
| 304 |
+
|
| 305 |
+
for i in target_indices:
|
| 306 |
+
npz_path = safe_path_join(base_dir, f"{base_name}_{i:03d}.npz")
|
| 307 |
+
if not os.path.exists(npz_path):
|
| 308 |
+
continue
|
| 309 |
+
|
| 310 |
+
try:
|
| 311 |
+
with _FILE_ACCESS_LOCK:
|
| 312 |
+
with np.load(npz_path, allow_pickle=False) as data:
|
| 313 |
+
# read single person data
|
| 314 |
+
gender = str(data["gender"][0])
|
| 315 |
+
Rh = data["Rh"]
|
| 316 |
+
Th = data["trans"]
|
| 317 |
+
poses = data["poses"]
|
| 318 |
+
betas = data["betas"]
|
| 319 |
+
|
| 320 |
+
if poses.ndim == 3:
|
| 321 |
+
poses = poses.reshape(poses.shape[0], -1)
|
| 322 |
+
|
| 323 |
+
person_frames = []
|
| 324 |
+
for f in range(len(poses)):
|
| 325 |
+
frame = {
|
| 326 |
+
"id": i,
|
| 327 |
+
"gender": gender,
|
| 328 |
+
"Rh": Rh[f : f + 1].tolist(),
|
| 329 |
+
"Th": Th[f : f + 1].tolist(),
|
| 330 |
+
"poses": poses[f : f + 1].tolist(),
|
| 331 |
+
"shapes": betas.tolist(),
|
| 332 |
+
}
|
| 333 |
+
person_frames.append([frame])
|
| 334 |
+
all_people.append(person_frames)
|
| 335 |
+
except Exception as e:
|
| 336 |
+
print(f"Error loading {npz_path}: {e}")
|
| 337 |
+
|
| 338 |
+
# merge
|
| 339 |
+
combined_frames = []
|
| 340 |
+
max_frames = max(len(p) for p in all_people) if all_people else 0
|
| 341 |
+
for f_idx in range(max_frames):
|
| 342 |
+
frame_content = []
|
| 343 |
+
for person_seq in all_people:
|
| 344 |
+
if f_idx < len(person_seq):
|
| 345 |
+
frame_content.extend(person_seq[f_idx])
|
| 346 |
+
combined_frames.append(frame_content)
|
| 347 |
+
|
| 348 |
+
return combined_frames
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def generate_static_html_content(
|
| 352 |
+
folder_name: str,
|
| 353 |
+
file_name: str,
|
| 354 |
+
hide_captions: bool = False,
|
| 355 |
+
) -> str:
|
| 356 |
+
"""
|
| 357 |
+
Generate static HTML content with embedded SMPL data and captions.
|
| 358 |
+
All JavaScript code is embedded directly in the HTML template,
|
| 359 |
+
so no external static resources are needed.
|
| 360 |
+
|
| 361 |
+
Args:
|
| 362 |
+
folder_name: The folder name containing the NPZ/meta files
|
| 363 |
+
file_name: The base file name (without extension)
|
| 364 |
+
hide_captions: Whether to hide captions in the visualization
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
The HTML content as a string
|
| 368 |
+
"""
|
| 369 |
+
# Load SMPL data
|
| 370 |
+
smpl_frames = get_cached_smpl_frames(folder_name, file_name)
|
| 371 |
+
if not smpl_frames:
|
| 372 |
+
raise ValueError(f"No SMPL data found for {folder_name}/{file_name}")
|
| 373 |
+
|
| 374 |
+
# Load captions
|
| 375 |
+
captions = []
|
| 376 |
+
if not hide_captions:
|
| 377 |
+
captions = get_cached_captions(folder_name, file_name)
|
| 378 |
+
|
| 379 |
+
# Generate caption HTML
|
| 380 |
+
caption_html = _generate_caption_html(captions, hide_captions)
|
| 381 |
+
|
| 382 |
+
# Convert SMPL data to JSON
|
| 383 |
+
smpl_data_json = json.dumps(smpl_frames, ensure_ascii=False)
|
| 384 |
+
|
| 385 |
+
# Load template
|
| 386 |
+
template_path = os.path.join(_TEMPLATE_DIR, "index_wooden_static.html")
|
| 387 |
+
with open(template_path, "r", encoding="utf-8") as f:
|
| 388 |
+
template_content = f.read()
|
| 389 |
+
|
| 390 |
+
# Replace placeholders with actual data
|
| 391 |
+
html_content = template_content.replace("{{ smpl_data_json }}", smpl_data_json)
|
| 392 |
+
html_content = html_content.replace("{{ caption_html }}", caption_html)
|
| 393 |
+
|
| 394 |
+
print(f">>> Generated static HTML content for {folder_name}/{file_name}")
|
| 395 |
+
return html_content
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def generate_static_html(
|
| 399 |
+
folder_name: str,
|
| 400 |
+
file_name: str,
|
| 401 |
+
output_dir: str,
|
| 402 |
+
hide_captions: bool = False,
|
| 403 |
+
) -> str:
|
| 404 |
+
"""
|
| 405 |
+
Generate a static HTML file with embedded SMPL data and captions.
|
| 406 |
+
All JavaScript code is embedded directly in the HTML template,
|
| 407 |
+
so no external static resources are needed.
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
folder_name: The folder name containing the NPZ/meta files
|
| 411 |
+
file_name: The base file name (without extension)
|
| 412 |
+
output_dir: Directory to save the generated HTML file
|
| 413 |
+
hide_captions: Whether to hide captions in the visualization
|
| 414 |
+
|
| 415 |
+
Returns:
|
| 416 |
+
The path to the generated HTML file
|
| 417 |
+
"""
|
| 418 |
+
html_content = generate_static_html_content(folder_name, file_name, hide_captions)
|
| 419 |
+
|
| 420 |
+
# Generate output path
|
| 421 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 422 |
+
output_html_path = os.path.join(output_dir, f"{file_name}_vis.html")
|
| 423 |
+
|
| 424 |
+
# Write HTML file
|
| 425 |
+
with _FILE_ACCESS_LOCK:
|
| 426 |
+
with open(output_html_path, "w", encoding="utf-8") as f:
|
| 427 |
+
f.write(html_content)
|
| 428 |
+
|
| 429 |
+
print(f">>> Generated static HTML: {output_html_path}")
|
| 430 |
+
return output_html_path
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def _generate_caption_html(captions: List[dict], hide_captions: bool = False) -> str:
|
| 434 |
+
"""
|
| 435 |
+
Generate the caption overlay HTML.
|
| 436 |
+
|
| 437 |
+
Args:
|
| 438 |
+
captions: List of caption dictionaries
|
| 439 |
+
hide_captions: Whether to hide captions
|
| 440 |
+
|
| 441 |
+
Returns:
|
| 442 |
+
HTML string for caption overlay
|
| 443 |
+
"""
|
| 444 |
+
if hide_captions or not captions:
|
| 445 |
+
return ""
|
| 446 |
+
|
| 447 |
+
caption_items = []
|
| 448 |
+
for caption in captions:
|
| 449 |
+
# Get the display text (prefer rewritten text)
|
| 450 |
+
text = caption.get("short caption+") or caption.get("short caption") or "No caption"
|
| 451 |
+
caption_items.append(f'<div class="caption-item">{text}</div>')
|
| 452 |
+
|
| 453 |
+
captions_html = "\n".join(caption_items)
|
| 454 |
+
|
| 455 |
+
return f'''
|
| 456 |
+
<div class="caption-overlay">
|
| 457 |
+
<div class="motion-info" id="motion-info">
|
| 458 |
+
<div class="captions-section">
|
| 459 |
+
{captions_html}
|
| 460 |
+
</div>
|
| 461 |
+
</div>
|
| 462 |
+
</div>
|
| 463 |
+
'''
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple
|
| 2 |
+
huggingface_hub==0.30.0
|
| 3 |
+
|
| 4 |
+
torch==2.5.1
|
| 5 |
+
torchvision==0.20.1
|
| 6 |
+
accelerate==0.30.1
|
| 7 |
+
diffusers==0.26.3
|
| 8 |
+
transformers==4.53.3
|
| 9 |
+
einops==0.8.1
|
| 10 |
+
safetensors==0.5.3
|
| 11 |
+
|
| 12 |
+
numpy>=1.24.0,<2.0
|
| 13 |
+
scipy>=1.10.0
|
| 14 |
+
transforms3d==0.4.2
|
| 15 |
+
|
| 16 |
+
PyYAML==6.0
|
| 17 |
+
omegaconf==2.3.0
|
| 18 |
+
click==8.1.3
|
| 19 |
+
requests==2.32.4
|
| 20 |
+
openai==1.78.1
|
| 21 |
+
|
| 22 |
+
fbxsdkpy==2020.1.post2
|
| 23 |
+
|
| 24 |
+
torchdiffeq==0.2.5
|
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp
ADDED
|
Git LFS Details
|
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp
ADDED
|
Git LFS Details
|
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp
ADDED
|
Git LFS Details
|
scripts/gradio/static/assets/dump_wooden/faces.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:777b0806d2843c797ed18644eecc11466ff822b33b02d263c22f8ad3730e9bb5
|
| 3 |
+
size 290376
|
scripts/gradio/static/assets/dump_wooden/j_template.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f488cc9d0b650a816f5a8eda49eda7bc796f490e9130a6e0dec5be137d7b929
|
| 3 |
+
size 624
|
scripts/gradio/static/assets/dump_wooden/joint_names.json
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"Pelvis",
|
| 3 |
+
"L_Hip",
|
| 4 |
+
"R_Hip",
|
| 5 |
+
"Spine1",
|
| 6 |
+
"L_Knee",
|
| 7 |
+
"R_Knee",
|
| 8 |
+
"Spine2",
|
| 9 |
+
"L_Ankle",
|
| 10 |
+
"R_Ankle",
|
| 11 |
+
"Spine3",
|
| 12 |
+
"L_Foot",
|
| 13 |
+
"R_Foot",
|
| 14 |
+
"Neck",
|
| 15 |
+
"L_Collar",
|
| 16 |
+
"R_Collar",
|
| 17 |
+
"Head",
|
| 18 |
+
"L_Shoulder",
|
| 19 |
+
"R_Shoulder",
|
| 20 |
+
"L_Elbow",
|
| 21 |
+
"R_Elbow",
|
| 22 |
+
"L_Wrist",
|
| 23 |
+
"R_Wrist",
|
| 24 |
+
"L_Index1",
|
| 25 |
+
"L_Index2",
|
| 26 |
+
"L_Index3",
|
| 27 |
+
"L_Middle1",
|
| 28 |
+
"L_Middle2",
|
| 29 |
+
"L_Middle3",
|
| 30 |
+
"L_Pinky1",
|
| 31 |
+
"L_Pinky2",
|
| 32 |
+
"L_Pinky3",
|
| 33 |
+
"L_Ring1",
|
| 34 |
+
"L_Ring2",
|
| 35 |
+
"L_Ring3",
|
| 36 |
+
"L_Thumb1",
|
| 37 |
+
"L_Thumb2",
|
| 38 |
+
"L_Thumb3",
|
| 39 |
+
"R_Index1",
|
| 40 |
+
"R_Index2",
|
| 41 |
+
"R_Index3",
|
| 42 |
+
"R_Middle1",
|
| 43 |
+
"R_Middle2",
|
| 44 |
+
"R_Middle3",
|
| 45 |
+
"R_Pinky1",
|
| 46 |
+
"R_Pinky2",
|
| 47 |
+
"R_Pinky3",
|
| 48 |
+
"R_Ring1",
|
| 49 |
+
"R_Ring2",
|
| 50 |
+
"R_Ring3",
|
| 51 |
+
"R_Thumb1",
|
| 52 |
+
"R_Thumb2",
|
| 53 |
+
"R_Thumb3"
|
| 54 |
+
]
|
scripts/gradio/static/assets/dump_wooden/joints.ply
ADDED
|
Binary file (782 Bytes). View file
|
|
|
scripts/gradio/static/assets/dump_wooden/keypoints.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f488cc9d0b650a816f5a8eda49eda7bc796f490e9130a6e0dec5be137d7b929
|
| 3 |
+
size 624
|
scripts/gradio/static/assets/dump_wooden/kintree.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98a20fa3b53193790b63d9ac3a9c917f2f70fbe6e053dca495c75317ff4b756a
|
| 3 |
+
size 208
|
scripts/gradio/static/assets/dump_wooden/skinIndice.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:846794fb90ea01e069435ad242caefb3d5c2f913fef3255c247f48836ddc1bda
|
| 3 |
+
size 194048
|
scripts/gradio/static/assets/dump_wooden/skinWeights.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:343eac2902627ee6b45b547eb7d9f1526562eca7ba178d1dc71f9e466f307a77
|
| 3 |
+
size 388096
|
scripts/gradio/static/assets/dump_wooden/uvs.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:23c92c60b609261d927990d31cbf6ace0c14cb70a5ac753a5f3927cb8c5c8191
|
| 3 |
+
size 194048
|
scripts/gradio/static/assets/dump_wooden/v_template.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b9fbe1f34bfe8a07442d11166e169318022a18da8bc62ce0a9930dfdb3171050
|
| 3 |
+
size 291072
|
scripts/gradio/templates/index_wooden_static.html
ADDED
|
@@ -0,0 +1,1205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html>
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<title>Motion Visualization</title>
|
| 6 |
+
<meta charset="UTF-8">
|
| 7 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 8 |
+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
| 9 |
+
<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
|
| 10 |
+
<script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.10.2/dist/umd/popper.min.js"></script>
|
| 11 |
+
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.min.js"></script>
|
| 12 |
+
<style>
|
| 13 |
+
html, body {
|
| 14 |
+
background: #1a1a2e !important;
|
| 15 |
+
color: #e2e8f0;
|
| 16 |
+
margin: 0;
|
| 17 |
+
padding: 0;
|
| 18 |
+
}
|
| 19 |
+
.container {
|
| 20 |
+
padding: 0;
|
| 21 |
+
border: none;
|
| 22 |
+
background: #1a1a2e;
|
| 23 |
+
}
|
| 24 |
+
.alert-success {
|
| 25 |
+
display: none;
|
| 26 |
+
}
|
| 27 |
+
</style>
|
| 28 |
+
</head>
|
| 29 |
+
|
| 30 |
+
<body>
|
| 31 |
+
|
| 32 |
+
<!-- Fullscreen 3D container -->
|
| 33 |
+
<div class="fullscreen-container">
|
| 34 |
+
<!-- 3D viewport -->
|
| 35 |
+
<div id="vis3d"></div>
|
| 36 |
+
|
| 37 |
+
<!-- Floating caption overlay (centered at top) -->
|
| 38 |
+
{{ caption_html }}
|
| 39 |
+
|
| 40 |
+
<!-- Floating progress control panel (centered at bottom) -->
|
| 41 |
+
<div class="control-overlay">
|
| 42 |
+
<div class="control-row-minimal">
|
| 43 |
+
<div class="progress-container">
|
| 44 |
+
<input type="range" id="progressSlider" class="progress-slider-minimal" min="0" max="100" value="0">
|
| 45 |
+
</div>
|
| 46 |
+
<div class="frame-counter">
|
| 47 |
+
<span id="currentFrame">0</span> / <span id="totalFrames">0</span>
|
| 48 |
+
</div>
|
| 49 |
+
</div>
|
| 50 |
+
</div>
|
| 51 |
+
|
| 52 |
+
<!-- Loading status overlay -->
|
| 53 |
+
<div class="loading-overlay" id="loadingStatus">
|
| 54 |
+
<i class="fas fa-spinner fa-spin"></i> Loading...
|
| 55 |
+
</div>
|
| 56 |
+
|
| 57 |
+
<!-- Hidden controls for functionality -->
|
| 58 |
+
<div style="display: none;">
|
| 59 |
+
<button id="playPauseBtn"></button>
|
| 60 |
+
<button id="resetBtn"></button>
|
| 61 |
+
<input type="range" id="speedSlider" min="0.1" max="3" step="0.1" value="1">
|
| 62 |
+
<span id="speedValue">1.0x</span>
|
| 63 |
+
</div>
|
| 64 |
+
</div>
|
| 65 |
+
|
| 66 |
+
<!-- Add Font Awesome for icons -->
|
| 67 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
|
| 68 |
+
|
| 69 |
+
<script type="importmap">
|
| 70 |
+
{
|
| 71 |
+
"imports": {
|
| 72 |
+
"three": "https://cdn.jsdelivr.net/npm/three@0.160.0/build/three.module.js",
|
| 73 |
+
"three/addons/": "https://cdn.jsdelivr.net/npm/three@0.160.0/examples/jsm/"
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
</script>
|
| 77 |
+
|
| 78 |
+
<!-- Embedded SMPL Data - Generated by Python -->
|
| 79 |
+
<script type="application/json" id="smpl-data-json">
|
| 80 |
+
{{ smpl_data_json }}
|
| 81 |
+
</script>
|
| 82 |
+
|
| 83 |
+
<script type="module">
|
| 84 |
+
import * as THREE from 'three';
|
| 85 |
+
import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
|
| 86 |
+
|
| 87 |
+
// ============================================================
|
| 88 |
+
// EMBEDDED: create_ground.js functions
|
| 89 |
+
// ============================================================
|
| 90 |
+
|
| 91 |
+
function getAdaptiveGridSize(sample_data, default_size = 5) {
|
| 92 |
+
if (sample_data) {
|
| 93 |
+
const bounds = calculateDataBounds(sample_data);
|
| 94 |
+
const grid_size = Math.max(bounds.maxRange * 3, 5);
|
| 95 |
+
console.log(`Adaptive ground size: ${grid_size.toFixed(2)}, data range: ${bounds.maxRange.toFixed(2)}`);
|
| 96 |
+
return grid_size;
|
| 97 |
+
}
|
| 98 |
+
return default_size;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
function createBaseChessboard(
|
| 102 |
+
grid_size = 5,
|
| 103 |
+
divisions = 10,
|
| 104 |
+
white = "#ffffff",
|
| 105 |
+
black = "#444444",
|
| 106 |
+
texture_size = 1024,
|
| 107 |
+
sample_data = null,
|
| 108 |
+
) {
|
| 109 |
+
if (sample_data) {
|
| 110 |
+
grid_size = getAdaptiveGridSize(sample_data, grid_size);
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
var adjusted_texture_size = Math.floor(texture_size / divisions) * divisions;
|
| 114 |
+
var canvas = document.createElement("canvas");
|
| 115 |
+
canvas.width = canvas.height = adjusted_texture_size;
|
| 116 |
+
var context = canvas.getContext("2d");
|
| 117 |
+
context.imageSmoothingEnabled = false;
|
| 118 |
+
|
| 119 |
+
var step = adjusted_texture_size / divisions;
|
| 120 |
+
for (var i = 0; i < divisions; i++) {
|
| 121 |
+
for (var j = 0; j < divisions; j++) {
|
| 122 |
+
context.fillStyle = (i + j) % 2 === 0 ? white : black;
|
| 123 |
+
context.fillRect(i * step, j * step, step, step);
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
var texture = new THREE.CanvasTexture(canvas);
|
| 128 |
+
texture.wrapS = THREE.RepeatWrapping;
|
| 129 |
+
texture.wrapT = THREE.RepeatWrapping;
|
| 130 |
+
texture.magFilter = THREE.NearestFilter;
|
| 131 |
+
texture.minFilter = THREE.NearestFilter;
|
| 132 |
+
texture.generateMipmaps = false;
|
| 133 |
+
|
| 134 |
+
var planeGeometry = new THREE.PlaneGeometry(grid_size, grid_size);
|
| 135 |
+
|
| 136 |
+
var planeMaterial = new THREE.MeshStandardMaterial({
|
| 137 |
+
map: texture,
|
| 138 |
+
side: THREE.DoubleSide,
|
| 139 |
+
transparent: true,
|
| 140 |
+
opacity: 0.85,
|
| 141 |
+
roughness: 0.9,
|
| 142 |
+
metalness: 0.1,
|
| 143 |
+
emissiveIntensity: 0.05,
|
| 144 |
+
});
|
| 145 |
+
|
| 146 |
+
var plane = new THREE.Mesh(planeGeometry, planeMaterial);
|
| 147 |
+
plane.receiveShadow = true;
|
| 148 |
+
|
| 149 |
+
return plane;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
function getChessboard(...args) {
|
| 153 |
+
var plane = createBaseChessboard(...args);
|
| 154 |
+
plane.rotation.x = -Math.PI;
|
| 155 |
+
return plane;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
function getChessboardXZ(...args) {
|
| 159 |
+
var plane = createBaseChessboard(...args);
|
| 160 |
+
plane.rotation.x = -Math.PI / 2;
|
| 161 |
+
return plane;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
function getCoordinate(axisLength) {
|
| 165 |
+
var axes = new THREE.Group();
|
| 166 |
+
var materialX = new THREE.LineBasicMaterial({ color: 0xff0000 });
|
| 167 |
+
var materialY = new THREE.LineBasicMaterial({ color: 0x00ff00 });
|
| 168 |
+
var materialZ = new THREE.LineBasicMaterial({ color: 0x0000ff });
|
| 169 |
+
|
| 170 |
+
var xAxisGeometry = new THREE.BufferGeometry().setFromPoints([
|
| 171 |
+
new THREE.Vector3(0, 0, 0),
|
| 172 |
+
new THREE.Vector3(axisLength, 0, 0),
|
| 173 |
+
]);
|
| 174 |
+
var yAxisGeometry = new THREE.BufferGeometry().setFromPoints([
|
| 175 |
+
new THREE.Vector3(0, 0, 0),
|
| 176 |
+
new THREE.Vector3(0, axisLength, 0),
|
| 177 |
+
]);
|
| 178 |
+
var zAxisGeometry = new THREE.BufferGeometry().setFromPoints([
|
| 179 |
+
new THREE.Vector3(0, 0, 0),
|
| 180 |
+
new THREE.Vector3(0, 0, axisLength),
|
| 181 |
+
]);
|
| 182 |
+
|
| 183 |
+
var xAxis = new THREE.Line(xAxisGeometry, materialX);
|
| 184 |
+
var yAxis = new THREE.Line(yAxisGeometry, materialY);
|
| 185 |
+
var zAxis = new THREE.Line(zAxisGeometry, materialZ);
|
| 186 |
+
|
| 187 |
+
axes.add(xAxis);
|
| 188 |
+
axes.add(yAxis);
|
| 189 |
+
axes.add(zAxis);
|
| 190 |
+
|
| 191 |
+
return axes;
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
function calculateDataBounds(sample_data) {
|
| 195 |
+
let minX = Infinity, maxX = -Infinity;
|
| 196 |
+
let minY = Infinity, maxY = -Infinity;
|
| 197 |
+
let minZ = Infinity, maxZ = -Infinity;
|
| 198 |
+
|
| 199 |
+
if (sample_data && sample_data.length > 0) {
|
| 200 |
+
sample_data.forEach((frame) => {
|
| 201 |
+
if (frame.positions && Array.isArray(frame.positions)) {
|
| 202 |
+
frame.positions.forEach((pos) => {
|
| 203 |
+
let x, y, z;
|
| 204 |
+
if (typeof pos === "object") {
|
| 205 |
+
x = pos.x !== undefined ? pos.x : pos[0];
|
| 206 |
+
y = pos.y !== undefined ? pos.y : pos[1];
|
| 207 |
+
z = pos.z !== undefined ? pos.z : pos[2];
|
| 208 |
+
} else if (Array.isArray(pos)) {
|
| 209 |
+
[x, y, z] = pos;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
if (x !== undefined && y !== undefined && z !== undefined) {
|
| 213 |
+
minX = Math.min(minX, x);
|
| 214 |
+
maxX = Math.max(maxX, x);
|
| 215 |
+
minY = Math.min(minY, y);
|
| 216 |
+
maxY = Math.max(maxY, y);
|
| 217 |
+
minZ = Math.min(minZ, z);
|
| 218 |
+
maxZ = Math.max(maxZ, z);
|
| 219 |
+
}
|
| 220 |
+
});
|
| 221 |
+
}
|
| 222 |
+
});
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
if (minX === Infinity || maxX === -Infinity) {
|
| 226 |
+
minX = maxX = minY = maxY = minZ = maxZ = 0;
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
const rangeX = Math.abs(maxX - minX);
|
| 230 |
+
const rangeY = Math.abs(maxY - minY);
|
| 231 |
+
const rangeZ = Math.abs(maxZ - minZ);
|
| 232 |
+
const maxRange = Math.max(rangeX, rangeZ);
|
| 233 |
+
|
| 234 |
+
return { minX, maxX, minY, maxY, minZ, maxZ, rangeX, rangeY, rangeZ, maxRange };
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
// ============================================================
|
| 238 |
+
// EMBEDDED: create_scene.js functions
|
| 239 |
+
// ============================================================
|
| 240 |
+
|
| 241 |
+
function create_scene(scene, camera, renderer, use_ground = true, axis_up = "z", axis_forward = "-y") {
|
| 242 |
+
const width = document.querySelector(".container") ? document.querySelector(".container").offsetWidth : window.innerWidth;
|
| 243 |
+
const height = width;
|
| 244 |
+
|
| 245 |
+
if (axis_up == "z") {
|
| 246 |
+
camera.up.set(0, 0, 1);
|
| 247 |
+
if (axis_forward == "-y") {
|
| 248 |
+
camera.position.set(0, -3, 3);
|
| 249 |
+
} else if (axis_forward == "y") {
|
| 250 |
+
camera.position.set(0, 3, 3);
|
| 251 |
+
}
|
| 252 |
+
camera.lookAt(new THREE.Vector3(0, 0, 1.5));
|
| 253 |
+
} else if (axis_up == "y") {
|
| 254 |
+
camera.up.set(0, 1, 0);
|
| 255 |
+
if (axis_forward == "z") {
|
| 256 |
+
camera.position.set(0, 2.5, 5);
|
| 257 |
+
} else if (axis_forward == "-z") {
|
| 258 |
+
camera.position.set(0, 2.5, -5);
|
| 259 |
+
}
|
| 260 |
+
camera.lookAt(new THREE.Vector3(0, 1, 0));
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
scene.background = new THREE.Color(0x000000);
|
| 264 |
+
scene.fog = new THREE.FogExp2(0x424242, 0.06);
|
| 265 |
+
|
| 266 |
+
renderer.shadowMap.enabled = true;
|
| 267 |
+
renderer.shadowMap.type = THREE.PCFSoftShadowMap;
|
| 268 |
+
|
| 269 |
+
const hemisphereLight = new THREE.HemisphereLight(0xffffff, 0x444444, 1.8);
|
| 270 |
+
hemisphereLight.position.set(0, 2, 0);
|
| 271 |
+
scene.add(hemisphereLight);
|
| 272 |
+
|
| 273 |
+
const directionalLight = new THREE.DirectionalLight(0xffffff, 1.5);
|
| 274 |
+
if (axis_up == "z") {
|
| 275 |
+
if (axis_forward == "-y") {
|
| 276 |
+
directionalLight.position.set(-3, 1, 5);
|
| 277 |
+
} else if (axis_forward == "y") {
|
| 278 |
+
directionalLight.position.set(3, 1, 5);
|
| 279 |
+
}
|
| 280 |
+
} else if (axis_up == "y") {
|
| 281 |
+
if (axis_forward == "z") {
|
| 282 |
+
directionalLight.position.set(3, 5, 4);
|
| 283 |
+
} else if (axis_forward == "-z") {
|
| 284 |
+
directionalLight.position.set(3, 5, -4);
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
directionalLight.castShadow = true;
|
| 288 |
+
directionalLight.shadow.mapSize.width = 2048;
|
| 289 |
+
directionalLight.shadow.mapSize.height = 2048;
|
| 290 |
+
directionalLight.shadow.camera.near = 0.5;
|
| 291 |
+
directionalLight.shadow.camera.far = 50;
|
| 292 |
+
directionalLight.shadow.camera.left = -10;
|
| 293 |
+
directionalLight.shadow.camera.right = 10;
|
| 294 |
+
directionalLight.shadow.camera.top = 10;
|
| 295 |
+
directionalLight.shadow.camera.bottom = -10;
|
| 296 |
+
directionalLight.shadow.bias = -0.0001;
|
| 297 |
+
scene.add(directionalLight);
|
| 298 |
+
|
| 299 |
+
const fillLight = new THREE.DirectionalLight(0xaaccff, 0.4);
|
| 300 |
+
fillLight.position.set(-3, 3, -2);
|
| 301 |
+
scene.add(fillLight);
|
| 302 |
+
|
| 303 |
+
const rimLight = new THREE.DirectionalLight(0xffeedd, 0.3);
|
| 304 |
+
rimLight.position.set(0, 4, -5);
|
| 305 |
+
scene.add(rimLight);
|
| 306 |
+
|
| 307 |
+
if (use_ground) {
|
| 308 |
+
if (axis_up == "z") {
|
| 309 |
+
var plane = getChessboard(50, 50, '#ffffff', '#3a3a3a', 1024);
|
| 310 |
+
plane.name = 'ground';
|
| 311 |
+
plane.receiveShadow = true;
|
| 312 |
+
scene.add(plane);
|
| 313 |
+
} else if (axis_up == "y") {
|
| 314 |
+
var plane = getChessboardXZ(50, 50, '#ffffff', '#3a3a3a', 1024);
|
| 315 |
+
plane.name = 'ground';
|
| 316 |
+
plane.receiveShadow = true;
|
| 317 |
+
scene.add(plane);
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
return 0;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
function fitCameraToScene(scene, camera, controls = null, opts = {}) {
|
| 325 |
+
const { margin = 1.05, axis_up = "y", excludeNames = ["ground"] } = opts;
|
| 326 |
+
|
| 327 |
+
const box = new THREE.Box3();
|
| 328 |
+
const tmp = new THREE.Box3();
|
| 329 |
+
let has = false;
|
| 330 |
+
|
| 331 |
+
scene.traverse((obj) => {
|
| 332 |
+
if (!obj || !obj.visible) return;
|
| 333 |
+
if (obj.isLight) return;
|
| 334 |
+
const t = obj.type || "";
|
| 335 |
+
if (t.endsWith("Helper")) return;
|
| 336 |
+
if (excludeNames && excludeNames.includes(obj.name)) return;
|
| 337 |
+
|
| 338 |
+
if (obj.isMesh) {
|
| 339 |
+
if (obj.geometry && obj.geometry.type === "PlaneGeometry") return;
|
| 340 |
+
try {
|
| 341 |
+
tmp.setFromObject(obj);
|
| 342 |
+
if (!tmp.isEmpty()) {
|
| 343 |
+
if (!has) {
|
| 344 |
+
box.copy(tmp);
|
| 345 |
+
has = true;
|
| 346 |
+
} else {
|
| 347 |
+
box.union(tmp);
|
| 348 |
+
}
|
| 349 |
+
}
|
| 350 |
+
} catch (_) {}
|
| 351 |
+
}
|
| 352 |
+
});
|
| 353 |
+
|
| 354 |
+
if (!has || box.isEmpty()) return;
|
| 355 |
+
|
| 356 |
+
const sphere = new THREE.Sphere();
|
| 357 |
+
box.getBoundingSphere(sphere);
|
| 358 |
+
const center = sphere.center.clone();
|
| 359 |
+
const radius = Math.max(sphere.radius, 1e-3);
|
| 360 |
+
|
| 361 |
+
const vFov = THREE.MathUtils.degToRad(camera.fov);
|
| 362 |
+
const hFov = 2 * Math.atan(Math.tan(vFov / 2) * camera.aspect);
|
| 363 |
+
const distV = radius / Math.sin(vFov / 2);
|
| 364 |
+
const distH = radius / Math.sin(hFov / 2);
|
| 365 |
+
const dist = Math.max(distV, distH) * margin;
|
| 366 |
+
|
| 367 |
+
const elev = THREE.MathUtils.degToRad(25);
|
| 368 |
+
const azim = Math.PI / 4;
|
| 369 |
+
const horiz = Math.cos(elev);
|
| 370 |
+
let dir;
|
| 371 |
+
|
| 372 |
+
if (axis_up === "y") {
|
| 373 |
+
dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.sin(elev), Math.cos(azim) * horiz);
|
| 374 |
+
camera.up.set(0, 1, 0);
|
| 375 |
+
} else {
|
| 376 |
+
dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.cos(azim) * horiz, Math.sin(elev));
|
| 377 |
+
camera.up.set(0, 0, 1);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
camera.position.copy(center).add(dir.multiplyScalar(dist));
|
| 381 |
+
camera.updateProjectionMatrix();
|
| 382 |
+
camera.lookAt(center);
|
| 383 |
+
|
| 384 |
+
if (controls) {
|
| 385 |
+
controls.target.copy(center);
|
| 386 |
+
controls.minDistance = Math.max(radius * 0.2, 0.1);
|
| 387 |
+
controls.maxDistance = Math.max(dist * 3, controls.minDistance + 0.1);
|
| 388 |
+
controls.update();
|
| 389 |
+
}
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
// ============================================================
|
| 393 |
+
// EMBEDDED: load_wooden.js functions
|
| 394 |
+
// ============================================================
|
| 395 |
+
|
| 396 |
+
const NUM_SKIN_WEIGHTS = 4;
|
| 397 |
+
|
| 398 |
+
const SMPLH_JOINT_NAMES = [
|
| 399 |
+
"Pelvis", "L_Hip", "R_Hip", "Spine1",
|
| 400 |
+
"L_Knee", "R_Knee", "Spine2",
|
| 401 |
+
"L_Ankle", "R_Ankle", "Spine3",
|
| 402 |
+
"L_Foot", "R_Foot", "Neck", "L_Collar", "R_Collar", "Head",
|
| 403 |
+
"L_Shoulder", "R_Shoulder", "L_Elbow", "R_Elbow",
|
| 404 |
+
"L_Wrist", "R_Wrist",
|
| 405 |
+
"L_Index1", "L_Index2", "L_Index3",
|
| 406 |
+
"L_Middle1", "L_Middle2", "L_Middle3",
|
| 407 |
+
"L_Pinky1", "L_Pinky2", "L_Pinky3",
|
| 408 |
+
"L_Ring1", "L_Ring2", "L_Ring3",
|
| 409 |
+
"L_Thumb1", "L_Thumb2", "L_Thumb3",
|
| 410 |
+
"R_Index1", "R_Index2", "R_Index3",
|
| 411 |
+
"R_Middle1", "R_Middle2", "R_Middle3",
|
| 412 |
+
"R_Pinky1", "R_Pinky2", "R_Pinky3",
|
| 413 |
+
"R_Ring1", "R_Ring2", "R_Ring3",
|
| 414 |
+
"R_Thumb1", "R_Thumb2", "R_Thumb3",
|
| 415 |
+
];
|
| 416 |
+
|
| 417 |
+
const DEFAULT_EDGES = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50];
|
| 418 |
+
|
| 419 |
+
async function load_wooden(shapes, gender, basePath = '/static/assets/dump_wooden') {
|
| 420 |
+
console.log("Loading wooden model...");
|
| 421 |
+
basePath = "https://raw.githubusercontent.com/chingswy/WoodenModel/refs/heads/main/dump_wooden"
|
| 422 |
+
console.log(`Using base path: ${basePath}`);
|
| 423 |
+
|
| 424 |
+
const urls = [
|
| 425 |
+
`${basePath}/v_template.bin`,
|
| 426 |
+
`${basePath}/faces.bin`,
|
| 427 |
+
`${basePath}/skinWeights.bin`,
|
| 428 |
+
`${basePath}/skinIndice.bin`,
|
| 429 |
+
`${basePath}/j_template.bin`,
|
| 430 |
+
`${basePath}/uvs.bin`,
|
| 431 |
+
];
|
| 432 |
+
|
| 433 |
+
let edges = [...DEFAULT_EDGES];
|
| 434 |
+
try {
|
| 435 |
+
const kintreeResponse = await fetch(`${basePath}/kintree.bin`);
|
| 436 |
+
if (kintreeResponse.ok) {
|
| 437 |
+
const kintreeBuffer = await kintreeResponse.arrayBuffer();
|
| 438 |
+
edges = Array.from(new Int32Array(kintreeBuffer));
|
| 439 |
+
console.log(`Loaded kintree with ${edges.length} joints`);
|
| 440 |
+
}
|
| 441 |
+
} catch (e) {
|
| 442 |
+
console.log('Using default kintree');
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
let jointNames = [...SMPLH_JOINT_NAMES];
|
| 446 |
+
try {
|
| 447 |
+
const namesResponse = await fetch(`${basePath}/joint_names.json`);
|
| 448 |
+
if (namesResponse.ok) {
|
| 449 |
+
jointNames = await namesResponse.json();
|
| 450 |
+
console.log(`Loaded ${jointNames.length} joint names`);
|
| 451 |
+
}
|
| 452 |
+
} catch (e) {
|
| 453 |
+
console.log('Using default joint names');
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
const buffers = await Promise.all(urls.map(url => fetch(url).then(response => response.arrayBuffer())));
|
| 457 |
+
const v_template = new Float32Array(buffers[0]);
|
| 458 |
+
const faces = new Uint16Array(buffers[1]);
|
| 459 |
+
const skinWeights = new Float32Array(buffers[2]);
|
| 460 |
+
const skinIndices = new Uint16Array(buffers[3]);
|
| 461 |
+
const keypoints = new Float32Array(buffers[4]);
|
| 462 |
+
const uvs = new Float32Array(buffers[5]);
|
| 463 |
+
|
| 464 |
+
console.log(`Vertices: ${v_template.length / 3}, Faces: ${faces.length / 3}, Joints: ${keypoints.length / 3}`);
|
| 465 |
+
|
| 466 |
+
const geometry = new THREE.BufferGeometry();
|
| 467 |
+
geometry.setAttribute('position', new THREE.BufferAttribute(v_template, 3));
|
| 468 |
+
geometry.setIndex(new THREE.BufferAttribute(faces, 1));
|
| 469 |
+
geometry.setAttribute('skinIndex', new THREE.BufferAttribute(skinIndices, NUM_SKIN_WEIGHTS));
|
| 470 |
+
geometry.setAttribute('skinWeight', new THREE.BufferAttribute(skinWeights, NUM_SKIN_WEIGHTS));
|
| 471 |
+
geometry.setAttribute('uv', new THREE.BufferAttribute(uvs, 2));
|
| 472 |
+
|
| 473 |
+
const numJoints = keypoints.length / 3;
|
| 474 |
+
|
| 475 |
+
while (edges.length < numJoints) {
|
| 476 |
+
edges.push(0);
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
var rootBone = new THREE.Bone();
|
| 480 |
+
rootBone.position.set(keypoints[0], keypoints[1], keypoints[2]);
|
| 481 |
+
rootBone.name = jointNames[0] || 'Pelvis';
|
| 482 |
+
var bones = [rootBone];
|
| 483 |
+
|
| 484 |
+
for (let i = 1; i < numJoints; i++) {
|
| 485 |
+
const bone = new THREE.Bone();
|
| 486 |
+
const parentIndex = edges[i];
|
| 487 |
+
|
| 488 |
+
if (parentIndex >= 0 && parentIndex < i) {
|
| 489 |
+
bone.position.set(
|
| 490 |
+
keypoints[3 * i] - keypoints[3 * parentIndex],
|
| 491 |
+
keypoints[3 * i + 1] - keypoints[3 * parentIndex + 1],
|
| 492 |
+
keypoints[3 * i + 2] - keypoints[3 * parentIndex + 2]
|
| 493 |
+
);
|
| 494 |
+
bone.name = jointNames[i] || `Joint_${i}`;
|
| 495 |
+
bones.push(bone);
|
| 496 |
+
bones[parentIndex].add(bone);
|
| 497 |
+
console.log(`Joint ${i} (${bone.name}): parent=${parentIndex}, pos=${bone.position.toArray()}`);
|
| 498 |
+
} else {
|
| 499 |
+
console.warn(`Invalid parent index ${parentIndex} for joint ${i}, attaching to root`);
|
| 500 |
+
bone.position.set(0, 0, 0);
|
| 501 |
+
bone.name = jointNames[i] || `Joint_${i}`;
|
| 502 |
+
bones.push(bone);
|
| 503 |
+
bones[0].add(bone);
|
| 504 |
+
}
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
var skeleton = new THREE.Skeleton(bones);
|
| 508 |
+
|
| 509 |
+
geometry.computeVertexNormals();
|
| 510 |
+
|
| 511 |
+
const textureLoader = new THREE.TextureLoader();
|
| 512 |
+
|
| 513 |
+
async function loadTextureAsync(url, isSRGB = true) {
|
| 514 |
+
const tex = await textureLoader.loadAsync(url);
|
| 515 |
+
tex.flipY = false;
|
| 516 |
+
if (isSRGB) tex.colorSpace = THREE.SRGBColorSpace;
|
| 517 |
+
return tex;
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
const [baseColorMap] = await Promise.all([
|
| 521 |
+
loadTextureAsync(`${basePath}/Boy_lambert4_BaseColor.webp`, true),
|
| 522 |
+
]);
|
| 523 |
+
|
| 524 |
+
const material = new THREE.MeshStandardMaterial({
|
| 525 |
+
map: baseColorMap,
|
| 526 |
+
roughness: 0.6,
|
| 527 |
+
metalness: 0.2,
|
| 528 |
+
envMapIntensity: 1.5,
|
| 529 |
+
});
|
| 530 |
+
|
| 531 |
+
var mesh = new THREE.SkinnedMesh(geometry, material);
|
| 532 |
+
mesh.castShadow = true;
|
| 533 |
+
mesh.receiveShadow = true;
|
| 534 |
+
mesh.add(bones[0]);
|
| 535 |
+
mesh.bind(skeleton);
|
| 536 |
+
|
| 537 |
+
console.log(`Wooden model loaded: ${numJoints} joints, ${v_template.length / 3} vertices`);
|
| 538 |
+
|
| 539 |
+
return { bones, skeleton, mesh, jointNames, edges };
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
// ============================================================
|
| 543 |
+
// Main Application Code
|
| 544 |
+
// ============================================================
|
| 545 |
+
|
| 546 |
+
let scene, camera, renderer;
|
| 547 |
+
let controls;
|
| 548 |
+
let infos;
|
| 549 |
+
let currentFrame = 0;
|
| 550 |
+
let total_frame = 0;
|
| 551 |
+
const baseIntervalTime = 30;
|
| 552 |
+
var model_mesh = {};
|
| 553 |
+
|
| 554 |
+
let isPlaying = false;
|
| 555 |
+
let lastFrameTime = 0;
|
| 556 |
+
let playbackSpeed = 1.0;
|
| 557 |
+
let animationId = null;
|
| 558 |
+
let modelsLoaded = false;
|
| 559 |
+
let expectedModelCount = 0;
|
| 560 |
+
let loadedModelCount = 0;
|
| 561 |
+
|
| 562 |
+
let ignoreGlobalTrans = false;
|
| 563 |
+
let currentOffsets = [];
|
| 564 |
+
|
| 565 |
+
const updateFrame = () => {
|
| 566 |
+
if (!infos || currentFrame >= total_frame || !modelsLoaded) return;
|
| 567 |
+
|
| 568 |
+
const info = infos[currentFrame];
|
| 569 |
+
let allModelsReady = true;
|
| 570 |
+
|
| 571 |
+
info.forEach(smpl_params => {
|
| 572 |
+
if (!(smpl_params.id in model_mesh)) {
|
| 573 |
+
allModelsReady = false;
|
| 574 |
+
}
|
| 575 |
+
});
|
| 576 |
+
|
| 577 |
+
if (!allModelsReady) {
|
| 578 |
+
return;
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
const offsets = computeOffsets(info.length);
|
| 582 |
+
currentOffsets = offsets;
|
| 583 |
+
|
| 584 |
+
info.forEach((smpl_params, b) => {
|
| 585 |
+
const bones = model_mesh[smpl_params.id];
|
| 586 |
+
const meshContainer = bones[0].parent;
|
| 587 |
+
|
| 588 |
+
if (ignoreGlobalTrans) {
|
| 589 |
+
meshContainer.position.set(-offsets[b], 0, 0);
|
| 590 |
+
} else {
|
| 591 |
+
meshContainer.position.set(
|
| 592 |
+
smpl_params.Th[0][0] - offsets[b],
|
| 593 |
+
smpl_params.Th[0][1],
|
| 594 |
+
smpl_params.Th[0][2]
|
| 595 |
+
);
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
var axis = new THREE.Vector3(smpl_params.Rh[0][0], smpl_params.Rh[0][1], smpl_params.Rh[0][2]);
|
| 599 |
+
var angle = axis.length();
|
| 600 |
+
axis.normalize();
|
| 601 |
+
var quaternion = new THREE.Quaternion().setFromAxisAngle(axis, angle);
|
| 602 |
+
bones[0].quaternion.copy(quaternion);
|
| 603 |
+
|
| 604 |
+
var poses_offset = 0;
|
| 605 |
+
|
| 606 |
+
if (smpl_params.poses[0].length == 69) {
|
| 607 |
+
poses_offset = -3;
|
| 608 |
+
}
|
| 609 |
+
|
| 610 |
+
for (let i = 1; i < bones.length; i++) {
|
| 611 |
+
const startIndex = poses_offset + 3 * i;
|
| 612 |
+
|
| 613 |
+
if (startIndex + 2 < smpl_params.poses[0].length) {
|
| 614 |
+
var axis = new THREE.Vector3(
|
| 615 |
+
smpl_params.poses[0][startIndex],
|
| 616 |
+
smpl_params.poses[0][startIndex + 1],
|
| 617 |
+
smpl_params.poses[0][startIndex + 2]
|
| 618 |
+
);
|
| 619 |
+
var angle = axis.length();
|
| 620 |
+
|
| 621 |
+
if (angle > 1e-6) {
|
| 622 |
+
axis.normalize();
|
| 623 |
+
var quaternion = new THREE.Quaternion().setFromAxisAngle(axis, angle);
|
| 624 |
+
bones[i].quaternion.copy(quaternion);
|
| 625 |
+
} else {
|
| 626 |
+
bones[i].quaternion.set(0, 0, 0, 1);
|
| 627 |
+
}
|
| 628 |
+
}
|
| 629 |
+
}
|
| 630 |
+
});
|
| 631 |
+
|
| 632 |
+
updateUI();
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
const playLoop = (currentTime) => {
|
| 636 |
+
if (isPlaying && currentTime - lastFrameTime >= (baseIntervalTime / playbackSpeed)) {
|
| 637 |
+
currentFrame += 1;
|
| 638 |
+
if (currentFrame >= total_frame) {
|
| 639 |
+
currentFrame = 0;
|
| 640 |
+
}
|
| 641 |
+
updateFrame();
|
| 642 |
+
lastFrameTime = currentTime;
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
if (isPlaying) {
|
| 646 |
+
animationId = requestAnimationFrame(playLoop);
|
| 647 |
+
}
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
const updateUI = () => {
|
| 651 |
+
document.getElementById('currentFrame').textContent = currentFrame;
|
| 652 |
+
document.getElementById('totalFrames').textContent = total_frame;
|
| 653 |
+
|
| 654 |
+
if (total_frame > 0) {
|
| 655 |
+
const progress = (currentFrame / total_frame) * 100;
|
| 656 |
+
document.getElementById('progressSlider').value = progress;
|
| 657 |
+
}
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
const updateLoadingStatus = () => {
|
| 661 |
+
const loadingElement = document.getElementById('loadingStatus');
|
| 662 |
+
if (!loadingElement) return;
|
| 663 |
+
|
| 664 |
+
if (modelsLoaded) {
|
| 665 |
+
loadingElement.innerHTML = '<i class="fas fa-check"></i> Ready';
|
| 666 |
+
loadingElement.className = 'loading-overlay complete';
|
| 667 |
+
setTimeout(() => {
|
| 668 |
+
loadingElement.className = 'loading-overlay hidden';
|
| 669 |
+
}, 1500);
|
| 670 |
+
} else {
|
| 671 |
+
loadingElement.innerHTML = `<i class="fas fa-spinner fa-spin"></i> Loading... (${loadedModelCount}/${expectedModelCount})`;
|
| 672 |
+
loadingElement.className = 'loading-overlay';
|
| 673 |
+
}
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
const updatePlayPauseButton = () => {
|
| 677 |
+
const playPauseBtn = document.getElementById('playPauseBtn');
|
| 678 |
+
if (playPauseBtn) {
|
| 679 |
+
if (isPlaying) {
|
| 680 |
+
playPauseBtn.innerHTML = '<i class="fas fa-pause"></i>';
|
| 681 |
+
playPauseBtn.title = 'Pause';
|
| 682 |
+
} else {
|
| 683 |
+
playPauseBtn.innerHTML = '<i class="fas fa-play"></i>';
|
| 684 |
+
playPauseBtn.title = 'Play';
|
| 685 |
+
}
|
| 686 |
+
}
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
const enablePlaybackControls = () => {
|
| 690 |
+
const playPauseBtn = document.getElementById('playPauseBtn');
|
| 691 |
+
const resetBtn = document.getElementById('resetBtn');
|
| 692 |
+
const progressSlider = document.getElementById('progressSlider');
|
| 693 |
+
const speedSlider = document.getElementById('speedSlider');
|
| 694 |
+
|
| 695 |
+
[playPauseBtn, resetBtn, progressSlider, speedSlider].forEach(element => {
|
| 696 |
+
if (element) {
|
| 697 |
+
element.disabled = false;
|
| 698 |
+
element.style.opacity = '1';
|
| 699 |
+
element.style.cursor = 'pointer';
|
| 700 |
+
}
|
| 701 |
+
});
|
| 702 |
+
|
| 703 |
+
updatePlayPauseButton();
|
| 704 |
+
}
|
| 705 |
+
|
| 706 |
+
const playAnimation = () => {
|
| 707 |
+
if (!isPlaying && total_frame > 0 && modelsLoaded) {
|
| 708 |
+
isPlaying = true;
|
| 709 |
+
lastFrameTime = performance.now();
|
| 710 |
+
animationId = requestAnimationFrame(playLoop);
|
| 711 |
+
updatePlayPauseButton();
|
| 712 |
+
}
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
const pauseAnimation = () => {
|
| 716 |
+
isPlaying = false;
|
| 717 |
+
if (animationId) {
|
| 718 |
+
cancelAnimationFrame(animationId);
|
| 719 |
+
animationId = null;
|
| 720 |
+
}
|
| 721 |
+
updatePlayPauseButton();
|
| 722 |
+
}
|
| 723 |
+
|
| 724 |
+
const resetAnimation = () => {
|
| 725 |
+
pauseAnimation();
|
| 726 |
+
currentFrame = 0;
|
| 727 |
+
updateFrame();
|
| 728 |
+
updatePlayPauseButton();
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
const initPlaybackControls = () => {
|
| 732 |
+
const progressSlider = document.getElementById('progressSlider');
|
| 733 |
+
|
| 734 |
+
let wasPlaying = false;
|
| 735 |
+
progressSlider.addEventListener('mousedown', () => {
|
| 736 |
+
if (!modelsLoaded) return;
|
| 737 |
+
wasPlaying = isPlaying;
|
| 738 |
+
if (isPlaying) pauseAnimation();
|
| 739 |
+
});
|
| 740 |
+
|
| 741 |
+
progressSlider.addEventListener('input', (e) => {
|
| 742 |
+
if (!modelsLoaded) return;
|
| 743 |
+
const progress = parseFloat(e.target.value);
|
| 744 |
+
currentFrame = Math.floor((progress / 100) * total_frame);
|
| 745 |
+
if (currentFrame >= total_frame) currentFrame = total_frame - 1;
|
| 746 |
+
if (currentFrame < 0) currentFrame = 0;
|
| 747 |
+
updateFrame();
|
| 748 |
+
});
|
| 749 |
+
|
| 750 |
+
progressSlider.addEventListener('mouseup', () => {
|
| 751 |
+
if (!modelsLoaded) return;
|
| 752 |
+
if (wasPlaying) playAnimation();
|
| 753 |
+
});
|
| 754 |
+
|
| 755 |
+
progressSlider.addEventListener('touchstart', () => {
|
| 756 |
+
if (!modelsLoaded) return;
|
| 757 |
+
wasPlaying = isPlaying;
|
| 758 |
+
if (isPlaying) pauseAnimation();
|
| 759 |
+
});
|
| 760 |
+
|
| 761 |
+
progressSlider.addEventListener('touchend', () => {
|
| 762 |
+
if (!modelsLoaded) return;
|
| 763 |
+
if (wasPlaying) playAnimation();
|
| 764 |
+
});
|
| 765 |
+
|
| 766 |
+
const speedSlider = document.getElementById('speedSlider');
|
| 767 |
+
const speedValue = document.getElementById('speedValue');
|
| 768 |
+
speedSlider.addEventListener('input', (e) => {
|
| 769 |
+
playbackSpeed = parseFloat(e.target.value);
|
| 770 |
+
speedValue.textContent = playbackSpeed.toFixed(1) + 'x';
|
| 771 |
+
});
|
| 772 |
+
|
| 773 |
+
document.addEventListener('keydown', (e) => {
|
| 774 |
+
if (!modelsLoaded) return;
|
| 775 |
+
switch (e.code) {
|
| 776 |
+
case 'Space':
|
| 777 |
+
e.preventDefault();
|
| 778 |
+
if (isPlaying) {
|
| 779 |
+
pauseAnimation();
|
| 780 |
+
} else {
|
| 781 |
+
playAnimation();
|
| 782 |
+
}
|
| 783 |
+
break;
|
| 784 |
+
case 'ArrowLeft':
|
| 785 |
+
e.preventDefault();
|
| 786 |
+
if (currentFrame > 0) {
|
| 787 |
+
currentFrame--;
|
| 788 |
+
updateFrame();
|
| 789 |
+
}
|
| 790 |
+
break;
|
| 791 |
+
case 'ArrowRight':
|
| 792 |
+
e.preventDefault();
|
| 793 |
+
if (currentFrame < total_frame - 1) {
|
| 794 |
+
currentFrame++;
|
| 795 |
+
updateFrame();
|
| 796 |
+
}
|
| 797 |
+
break;
|
| 798 |
+
case 'Home':
|
| 799 |
+
e.preventDefault();
|
| 800 |
+
resetAnimation();
|
| 801 |
+
break;
|
| 802 |
+
}
|
| 803 |
+
});
|
| 804 |
+
}
|
| 805 |
+
|
| 806 |
+
// Load embedded SMPL data directly (no fetch needed)
|
| 807 |
+
function loadEmbeddedData() {
|
| 808 |
+
try {
|
| 809 |
+
const smplDataElement = document.getElementById('smpl-data-json');
|
| 810 |
+
if (!smplDataElement) {
|
| 811 |
+
console.error('SMPL data element not found');
|
| 812 |
+
return;
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
+
const datas = JSON.parse(smplDataElement.textContent);
|
| 816 |
+
|
| 817 |
+
if (!datas || datas.length === 0) {
|
| 818 |
+
console.error('No SMPL data available');
|
| 819 |
+
return;
|
| 820 |
+
}
|
| 821 |
+
|
| 822 |
+
console.log(`Loaded ${datas.length} frames of embedded SMPL data`);
|
| 823 |
+
infos = datas;
|
| 824 |
+
total_frame = datas.length;
|
| 825 |
+
|
| 826 |
+
document.getElementById('progressSlider').max = 100;
|
| 827 |
+
updateUI();
|
| 828 |
+
updatePlayPauseButton();
|
| 829 |
+
|
| 830 |
+
expectedModelCount = infos[0].length;
|
| 831 |
+
|
| 832 |
+
loadedModelCount = 0;
|
| 833 |
+
modelsLoaded = false;
|
| 834 |
+
updateLoadingStatus();
|
| 835 |
+
|
| 836 |
+
infos[0].forEach(data => {
|
| 837 |
+
load_wooden(null, null).then(result => {
|
| 838 |
+
scene.add(result.mesh);
|
| 839 |
+
|
| 840 |
+
result.mesh.castShadow = true;
|
| 841 |
+
result.mesh.receiveShadow = true;
|
| 842 |
+
|
| 843 |
+
model_mesh[data.id] = result.bones;
|
| 844 |
+
|
| 845 |
+
loadedModelCount++;
|
| 846 |
+
|
| 847 |
+
if (loadedModelCount === expectedModelCount) {
|
| 848 |
+
modelsLoaded = true;
|
| 849 |
+
updateLoadingStatus();
|
| 850 |
+
updateFrame();
|
| 851 |
+
enablePlaybackControls();
|
| 852 |
+
fitCameraToScene(scene, camera, controls, { axis_up: 'y', excludeNames: ['ground'] });
|
| 853 |
+
setTimeout(() => playAnimation(), 500);
|
| 854 |
+
} else {
|
| 855 |
+
updateLoadingStatus();
|
| 856 |
+
}
|
| 857 |
+
}).catch(err => {
|
| 858 |
+
console.error("Failed to load wooden model:", err);
|
| 859 |
+
});
|
| 860 |
+
});
|
| 861 |
+
|
| 862 |
+
initPlaybackControls();
|
| 863 |
+
animate();
|
| 864 |
+
} catch (error) {
|
| 865 |
+
console.error('Error loading embedded data:', error);
|
| 866 |
+
}
|
| 867 |
+
}
|
| 868 |
+
|
| 869 |
+
init();
|
| 870 |
+
loadEmbeddedData();
|
| 871 |
+
|
| 872 |
+
function init() {
|
| 873 |
+
const width = window.innerWidth;
|
| 874 |
+
const height = window.innerHeight;
|
| 875 |
+
scene = new THREE.Scene();
|
| 876 |
+
camera = new THREE.PerspectiveCamera(45, width / height, 0.1, 50);
|
| 877 |
+
renderer = new THREE.WebGLRenderer({ antialias: true, logarithmicDepthBuffer: true });
|
| 878 |
+
|
| 879 |
+
create_scene(scene, camera, renderer, true, 'y', 'z');
|
| 880 |
+
|
| 881 |
+
renderer.shadowMap.enabled = true;
|
| 882 |
+
renderer.shadowMap.type = THREE.PCFSoftShadowMap;
|
| 883 |
+
|
| 884 |
+
scene.background = new THREE.Color(0x424242);
|
| 885 |
+
scene.fog = new THREE.FogExp2(0x424242, 0.06);
|
| 886 |
+
|
| 887 |
+
scene.children = scene.children.filter(child => !child.isLight);
|
| 888 |
+
|
| 889 |
+
const hemisphereLight = new THREE.HemisphereLight(0xffffff, 0x444444, 1.2);
|
| 890 |
+
hemisphereLight.position.set(0, 2, 0);
|
| 891 |
+
scene.add(hemisphereLight);
|
| 892 |
+
|
| 893 |
+
const directionalLight = new THREE.DirectionalLight(0xffffff, 1.5);
|
| 894 |
+
directionalLight.position.set(3, 5, 4);
|
| 895 |
+
directionalLight.castShadow = true;
|
| 896 |
+
directionalLight.shadow.mapSize.width = 2048;
|
| 897 |
+
directionalLight.shadow.mapSize.height = 2048;
|
| 898 |
+
directionalLight.shadow.camera.near = 0.5;
|
| 899 |
+
directionalLight.shadow.camera.far = 50;
|
| 900 |
+
directionalLight.shadow.camera.left = -10;
|
| 901 |
+
directionalLight.shadow.camera.right = 10;
|
| 902 |
+
directionalLight.shadow.camera.top = 10;
|
| 903 |
+
directionalLight.shadow.camera.bottom = -10;
|
| 904 |
+
directionalLight.shadow.bias = -0.0001;
|
| 905 |
+
scene.add(directionalLight);
|
| 906 |
+
|
| 907 |
+
const fillLight = new THREE.DirectionalLight(0xaaccff, 0.5);
|
| 908 |
+
fillLight.position.set(-3, 3, -2);
|
| 909 |
+
scene.add(fillLight);
|
| 910 |
+
|
| 911 |
+
const rimLight = new THREE.DirectionalLight(0xffeedd, 0.4);
|
| 912 |
+
rimLight.position.set(0, 4, -5);
|
| 913 |
+
scene.add(rimLight);
|
| 914 |
+
|
| 915 |
+
renderer.toneMapping = THREE.ACESFilmicToneMapping;
|
| 916 |
+
renderer.toneMappingExposure = 1.0;
|
| 917 |
+
renderer.outputColorSpace = THREE.SRGBColorSpace;
|
| 918 |
+
|
| 919 |
+
renderer.setPixelRatio(window.devicePixelRatio);
|
| 920 |
+
renderer.setSize(width, height);
|
| 921 |
+
var container = document.getElementById('vis3d');
|
| 922 |
+
container.appendChild(renderer.domElement);
|
| 923 |
+
|
| 924 |
+
window.addEventListener('resize', onWindowResize);
|
| 925 |
+
|
| 926 |
+
controls = new OrbitControls(camera, renderer.domElement);
|
| 927 |
+
controls.minDistance = 1;
|
| 928 |
+
controls.maxDistance = 15;
|
| 929 |
+
controls.enableDamping = true;
|
| 930 |
+
controls.dampingFactor = 0.05;
|
| 931 |
+
controls.target.set(0, 1, 0);
|
| 932 |
+
fitCameraToScene(scene, camera, controls, { axis_up: 'y', excludeNames: ['ground'] });
|
| 933 |
+
|
| 934 |
+
let isDragging = false;
|
| 935 |
+
let mouseDownTime = 0;
|
| 936 |
+
|
| 937 |
+
renderer.domElement.addEventListener('mousedown', () => {
|
| 938 |
+
isDragging = false;
|
| 939 |
+
mouseDownTime = Date.now();
|
| 940 |
+
});
|
| 941 |
+
|
| 942 |
+
renderer.domElement.addEventListener('mousemove', () => {
|
| 943 |
+
if (Date.now() - mouseDownTime > 150) {
|
| 944 |
+
isDragging = true;
|
| 945 |
+
}
|
| 946 |
+
});
|
| 947 |
+
|
| 948 |
+
renderer.domElement.addEventListener('mouseup', (e) => {
|
| 949 |
+
if (!isDragging && Date.now() - mouseDownTime < 300) {
|
| 950 |
+
if (modelsLoaded) {
|
| 951 |
+
isPlaying ? pauseAnimation() : playAnimation();
|
| 952 |
+
}
|
| 953 |
+
}
|
| 954 |
+
});
|
| 955 |
+
|
| 956 |
+
renderer.domElement.addEventListener('dblclick', () => {
|
| 957 |
+
if (modelsLoaded) {
|
| 958 |
+
pauseAnimation();
|
| 959 |
+
currentFrame = 0;
|
| 960 |
+
updateFrame();
|
| 961 |
+
}
|
| 962 |
+
});
|
| 963 |
+
}
|
| 964 |
+
|
| 965 |
+
function animate() {
|
| 966 |
+
requestAnimationFrame(animate);
|
| 967 |
+
if (controls && controls.enableDamping) {
|
| 968 |
+
controls.update();
|
| 969 |
+
}
|
| 970 |
+
renderer.render(scene, camera);
|
| 971 |
+
}
|
| 972 |
+
|
| 973 |
+
function onWindowResize() {
|
| 974 |
+
const width = window.innerWidth;
|
| 975 |
+
const height = window.innerHeight;
|
| 976 |
+
camera.aspect = width / height;
|
| 977 |
+
camera.updateProjectionMatrix();
|
| 978 |
+
renderer.setSize(width, height);
|
| 979 |
+
}
|
| 980 |
+
|
| 981 |
+
function computeOffsets(batchSize) {
|
| 982 |
+
const spacing = 2.0;
|
| 983 |
+
const total_width = (batchSize - 1) * spacing;
|
| 984 |
+
const start_x = -total_width / 2;
|
| 985 |
+
const offsets = [];
|
| 986 |
+
for (let i = 0; i < batchSize; i++) {
|
| 987 |
+
offsets.push(start_x + i * spacing);
|
| 988 |
+
}
|
| 989 |
+
return offsets;
|
| 990 |
+
}
|
| 991 |
+
|
| 992 |
+
</script>
|
| 993 |
+
|
| 994 |
+
<style>
|
| 995 |
+
/* Fullscreen dark mode base styles */
|
| 996 |
+
* {
|
| 997 |
+
margin: 0;
|
| 998 |
+
padding: 0;
|
| 999 |
+
box-sizing: border-box;
|
| 1000 |
+
}
|
| 1001 |
+
|
| 1002 |
+
html, body {
|
| 1003 |
+
width: 100%;
|
| 1004 |
+
height: 100%;
|
| 1005 |
+
overflow: hidden;
|
| 1006 |
+
background: #424242 !important;
|
| 1007 |
+
color: #e2e8f0;
|
| 1008 |
+
}
|
| 1009 |
+
|
| 1010 |
+
/* Fullscreen container for 3D scene */
|
| 1011 |
+
.fullscreen-container {
|
| 1012 |
+
position: fixed;
|
| 1013 |
+
top: 0;
|
| 1014 |
+
left: 0;
|
| 1015 |
+
width: 100vw;
|
| 1016 |
+
height: 100vh;
|
| 1017 |
+
background: #424242;
|
| 1018 |
+
overflow: hidden;
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
#vis3d {
|
| 1022 |
+
position: absolute;
|
| 1023 |
+
top: 0;
|
| 1024 |
+
left: 0;
|
| 1025 |
+
width: 100%;
|
| 1026 |
+
height: 100%;
|
| 1027 |
+
background: #424242;
|
| 1028 |
+
}
|
| 1029 |
+
|
| 1030 |
+
#vis3d canvas {
|
| 1031 |
+
display: block;
|
| 1032 |
+
width: 100% !important;
|
| 1033 |
+
height: 100% !important;
|
| 1034 |
+
}
|
| 1035 |
+
|
| 1036 |
+
/* Floating caption overlay */
|
| 1037 |
+
.caption-overlay {
|
| 1038 |
+
position: absolute;
|
| 1039 |
+
top: 20px;
|
| 1040 |
+
left: 50%;
|
| 1041 |
+
transform: translateX(-50%);
|
| 1042 |
+
width: auto;
|
| 1043 |
+
max-width: 90%;
|
| 1044 |
+
z-index: 100;
|
| 1045 |
+
pointer-events: auto;
|
| 1046 |
+
}
|
| 1047 |
+
|
| 1048 |
+
.motion-info {
|
| 1049 |
+
background-color: rgba(45, 55, 72, 0.85);
|
| 1050 |
+
backdrop-filter: blur(10px);
|
| 1051 |
+
-webkit-backdrop-filter: blur(10px);
|
| 1052 |
+
border-radius: 20px;
|
| 1053 |
+
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.4);
|
| 1054 |
+
overflow: hidden;
|
| 1055 |
+
max-height: 40vh;
|
| 1056 |
+
overflow-y: auto;
|
| 1057 |
+
display: inline-block;
|
| 1058 |
+
}
|
| 1059 |
+
|
| 1060 |
+
/* Floating progress control panel */
|
| 1061 |
+
.control-overlay {
|
| 1062 |
+
position: absolute;
|
| 1063 |
+
bottom: 30px;
|
| 1064 |
+
left: 50%;
|
| 1065 |
+
transform: translateX(-50%);
|
| 1066 |
+
width: 80%;
|
| 1067 |
+
max-width: 600px;
|
| 1068 |
+
z-index: 100;
|
| 1069 |
+
background: rgba(0, 0, 0, 0.4);
|
| 1070 |
+
backdrop-filter: blur(8px);
|
| 1071 |
+
-webkit-backdrop-filter: blur(8px);
|
| 1072 |
+
padding: 15px 20px;
|
| 1073 |
+
border-radius: 12px;
|
| 1074 |
+
}
|
| 1075 |
+
|
| 1076 |
+
.control-row-minimal {
|
| 1077 |
+
display: flex;
|
| 1078 |
+
align-items: center;
|
| 1079 |
+
gap: 20px;
|
| 1080 |
+
}
|
| 1081 |
+
|
| 1082 |
+
.progress-container {
|
| 1083 |
+
flex: 1;
|
| 1084 |
+
}
|
| 1085 |
+
|
| 1086 |
+
.progress-slider-minimal {
|
| 1087 |
+
width: 100%;
|
| 1088 |
+
height: 8px;
|
| 1089 |
+
border-radius: 4px;
|
| 1090 |
+
background: rgba(255, 255, 255, 0.3);
|
| 1091 |
+
outline: none;
|
| 1092 |
+
cursor: pointer;
|
| 1093 |
+
-webkit-appearance: none;
|
| 1094 |
+
appearance: none;
|
| 1095 |
+
}
|
| 1096 |
+
|
| 1097 |
+
.progress-slider-minimal::-webkit-slider-runnable-track {
|
| 1098 |
+
width: 100%;
|
| 1099 |
+
height: 8px;
|
| 1100 |
+
border-radius: 4px;
|
| 1101 |
+
background: rgba(255, 255, 255, 0.3);
|
| 1102 |
+
}
|
| 1103 |
+
|
| 1104 |
+
.progress-slider-minimal::-webkit-slider-thumb {
|
| 1105 |
+
-webkit-appearance: none;
|
| 1106 |
+
appearance: none;
|
| 1107 |
+
width: 20px;
|
| 1108 |
+
height: 20px;
|
| 1109 |
+
border-radius: 50%;
|
| 1110 |
+
background: #4a9eff;
|
| 1111 |
+
cursor: pointer;
|
| 1112 |
+
border: 2px solid white;
|
| 1113 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.4);
|
| 1114 |
+
margin-top: -6px;
|
| 1115 |
+
}
|
| 1116 |
+
|
| 1117 |
+
.progress-slider-minimal::-moz-range-track {
|
| 1118 |
+
width: 100%;
|
| 1119 |
+
height: 8px;
|
| 1120 |
+
border-radius: 4px;
|
| 1121 |
+
background: rgba(255, 255, 255, 0.3);
|
| 1122 |
+
}
|
| 1123 |
+
|
| 1124 |
+
.progress-slider-minimal::-moz-range-thumb {
|
| 1125 |
+
width: 20px;
|
| 1126 |
+
height: 20px;
|
| 1127 |
+
border-radius: 50%;
|
| 1128 |
+
background: #4a9eff;
|
| 1129 |
+
cursor: pointer;
|
| 1130 |
+
border: 2px solid white;
|
| 1131 |
+
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.4);
|
| 1132 |
+
}
|
| 1133 |
+
|
| 1134 |
+
.frame-counter {
|
| 1135 |
+
font-family: 'SF Mono', 'Consolas', monospace;
|
| 1136 |
+
font-size: 14px;
|
| 1137 |
+
font-weight: 500;
|
| 1138 |
+
color: white;
|
| 1139 |
+
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
|
| 1140 |
+
white-space: nowrap;
|
| 1141 |
+
min-width: 80px;
|
| 1142 |
+
text-align: right;
|
| 1143 |
+
}
|
| 1144 |
+
|
| 1145 |
+
/* Loading overlay */
|
| 1146 |
+
.loading-overlay {
|
| 1147 |
+
position: absolute;
|
| 1148 |
+
top: 50%;
|
| 1149 |
+
left: 50%;
|
| 1150 |
+
transform: translate(-50%, -50%);
|
| 1151 |
+
background: rgba(0, 0, 0, 0.7);
|
| 1152 |
+
backdrop-filter: blur(8px);
|
| 1153 |
+
-webkit-backdrop-filter: blur(8px);
|
| 1154 |
+
color: white;
|
| 1155 |
+
padding: 15px 25px;
|
| 1156 |
+
border-radius: 10px;
|
| 1157 |
+
font-size: 14px;
|
| 1158 |
+
z-index: 200;
|
| 1159 |
+
display: flex;
|
| 1160 |
+
align-items: center;
|
| 1161 |
+
gap: 10px;
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
.loading-overlay.hidden {
|
| 1165 |
+
display: none;
|
| 1166 |
+
}
|
| 1167 |
+
|
| 1168 |
+
.loading-overlay.complete {
|
| 1169 |
+
background: rgba(76, 175, 80, 0.85);
|
| 1170 |
+
}
|
| 1171 |
+
|
| 1172 |
+
/* Caption content styles */
|
| 1173 |
+
.loading {
|
| 1174 |
+
padding: 10px 18px;
|
| 1175 |
+
text-align: center;
|
| 1176 |
+
color: #a0aec0;
|
| 1177 |
+
font-style: italic;
|
| 1178 |
+
white-space: nowrap;
|
| 1179 |
+
}
|
| 1180 |
+
|
| 1181 |
+
.captions-section {
|
| 1182 |
+
padding: 12px 20px;
|
| 1183 |
+
white-space: nowrap;
|
| 1184 |
+
}
|
| 1185 |
+
|
| 1186 |
+
.caption-item {
|
| 1187 |
+
background: transparent;
|
| 1188 |
+
border: none;
|
| 1189 |
+
border-radius: 0;
|
| 1190 |
+
margin-bottom: 6px;
|
| 1191 |
+
padding: 0;
|
| 1192 |
+
color: #f0f4f8;
|
| 1193 |
+
font-size: 1em;
|
| 1194 |
+
font-weight: 500;
|
| 1195 |
+
line-height: 1.5;
|
| 1196 |
+
text-align: center;
|
| 1197 |
+
}
|
| 1198 |
+
|
| 1199 |
+
.caption-item:last-child {
|
| 1200 |
+
margin-bottom: 0;
|
| 1201 |
+
}
|
| 1202 |
+
</style>
|
| 1203 |
+
|
| 1204 |
+
</body>
|
| 1205 |
+
</html>
|