chingshuai commited on
Commit
76957e3
·
1 Parent(s): 2fd136b
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. .gitignore +26 -0
  3. README.md +91 -7
  4. assets/arch.png +3 -0
  5. assets/banner.png +3 -0
  6. assets/config_simplified.yml +37 -0
  7. assets/pipeline.png +3 -0
  8. assets/sotacomp.png +3 -0
  9. assets/teaser.png +3 -0
  10. assets/wooden_models/boy_Rigging_smplx_tex.fbx +3 -0
  11. examples/example_prompts/example_subset.json +61 -0
  12. gradio_app.py +894 -0
  13. hymotion/network/attention.py +110 -0
  14. hymotion/network/bricks.py +46 -0
  15. hymotion/network/encoders.py +121 -0
  16. hymotion/network/hymotion_mmdit.py +636 -0
  17. hymotion/network/modulate_layers.py +49 -0
  18. hymotion/network/positional_encoding.py +174 -0
  19. hymotion/network/text_encoders/model_constants.py +8 -0
  20. hymotion/network/text_encoders/text_encoder.py +293 -0
  21. hymotion/network/token_refiner.py +192 -0
  22. hymotion/pipeline/body_model.py +412 -0
  23. hymotion/pipeline/motion_diffusion.py +639 -0
  24. hymotion/prompt_engineering/model_constants.py +42 -0
  25. hymotion/prompt_engineering/prompt_rewrite.py +284 -0
  26. hymotion/utils/configs.py +344 -0
  27. hymotion/utils/geometry.py +856 -0
  28. hymotion/utils/loaders.py +184 -0
  29. hymotion/utils/misc.py +113 -0
  30. hymotion/utils/motion_process.py +63 -0
  31. hymotion/utils/path.py +168 -0
  32. hymotion/utils/smplh2woodfbx.py +626 -0
  33. hymotion/utils/t2m_runtime.py +400 -0
  34. hymotion/utils/type_converter.py +22 -0
  35. hymotion/utils/visualize_mesh_web.py +463 -0
  36. requirements.txt +24 -0
  37. scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp +3 -0
  38. scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp +3 -0
  39. scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp +3 -0
  40. scripts/gradio/static/assets/dump_wooden/faces.bin +3 -0
  41. scripts/gradio/static/assets/dump_wooden/j_template.bin +3 -0
  42. scripts/gradio/static/assets/dump_wooden/joint_names.json +54 -0
  43. scripts/gradio/static/assets/dump_wooden/joints.ply +0 -0
  44. scripts/gradio/static/assets/dump_wooden/keypoints.bin +3 -0
  45. scripts/gradio/static/assets/dump_wooden/kintree.bin +3 -0
  46. scripts/gradio/static/assets/dump_wooden/skinIndice.bin +3 -0
  47. scripts/gradio/static/assets/dump_wooden/skinWeights.bin +3 -0
  48. scripts/gradio/static/assets/dump_wooden/uvs.bin +3 -0
  49. scripts/gradio/static/assets/dump_wooden/v_template.bin +3 -0
  50. scripts/gradio/templates/index_wooden_static.html +1205 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 3rdparty/fbxsdkpy-2020.1.post2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
37
+ assets/wooden_models/*.fbx filter=lfs diff=lfs merge=lfs -text
38
+ assets/wooden_models/boy_Rigging_smplx_tex.fbm/Boy_lambert4_BaseColor.png filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
40
+ *.webp filter=lfs diff=lfs merge=lfs -text
41
+ *.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cache_config.json
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ *.pyw
7
+ *.pyz
8
+ *.pywz
9
+ *.pyzw
10
+ *.pyzwz
11
+ cache
12
+
13
+ .vscode/
14
+
15
+ ckpts/*
16
+ !ckpts/README.md
17
+ assets/body_models/*
18
+ !assets/body_models/README.md
19
+ scripts/gradio/static/assets/dump_smplh
20
+ scripts/gradio/static/assets/export_wooden_to_js.py
21
+
22
+ test_*/
23
+ debug
24
+ tencent
25
+ output
26
+ assets/wooden_models/boy_Rigging_smplx_tex.fbm/*
README.md CHANGED
@@ -1,12 +1,96 @@
1
  ---
2
- title: HY Motion 1.0
3
- emoji: 📉
4
- colorFrom: red
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 6.2.0
8
- app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: HY-Motion-1.0
3
+ emoji: 💃
4
+ colorFrom: purple
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: gradio_app.py
9
  pinned: false
10
+ short_description: Text-to-3D and Image-to-3D Generation
11
  ---
12
 
13
+
14
+ <p align="center">
15
+ <img src="./assets/banner.png" alt="Banner" width="100%">
16
+ </p>
17
+
18
+ <div align="center">
19
+ <a href="https://hunyuan.tencent.com/motion" target="_blank">
20
+ <img src="https://img.shields.io/badge/Official%20Site-333399.svg?logo=homepage" height="22px" alt="Official Site">
21
+ </a>
22
+ <a href="https://huggingface.co/spaces/tencent/HY-Motion-1.0" target="_blank">
23
+ <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Demo-276cb4.svg" height="22px" alt="HuggingFace Space">
24
+ </a>
25
+ <a href="https://huggingface.co/tencent/HY-Motion-1.0" target="_blank">
26
+ <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg" height="22px" alt="HuggingFace Models">
27
+ </a>
28
+ <a href="https://arxiv.org/pdf/2512.23464" target="_blank">
29
+ <img src="https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv" height="22px" alt="ArXiv Report">
30
+ </a>
31
+ <a href="https://x.com/TencentHunyuan" target="_blank">
32
+ <img src="https://img.shields.io/badge/Hunyuan-black.svg?logo=x" height="22px" alt="X (Twitter)">
33
+ </a>
34
+ </div>
35
+
36
+
37
+ # HY-Motion 1.0: Scaling Flow Matching Models for 3D Motion Generation
38
+
39
+
40
+ <p align="center">
41
+ <img src="./assets/teaser.png" alt="Teaser" width="90%">
42
+ </p>
43
+
44
+
45
+ ## 🔥 News
46
+ - **Dec 30, 2025**: 🤗 We released the inference code and pretrained models of [HY-Motion 1.0](https://huggingface.co/tencent/HY-Motion-1.0). Please give it a try via our [HuggingFace Space](https://huggingface.co/spaces/tencent/HY-Motion-1.0) and our [Official Site](https://hunyuan.tencent.com/motion)!
47
+
48
+
49
+ ## **Introduction**
50
+
51
+ **HY-Motion 1.0** is a series of text-to-3D human motion generation models based on Diffusion Transformer (DiT) and Flow Matching. It allows developers to generate skeleton-based 3D character animations from simple text prompts, which can be directly integrated into various 3D animation pipelines. This model series is the first to scale DiT-based text-to-motion models to the billion-parameter level, achieving significant improvements in instruction-following capabilities and motion quality over existing open-source models.
52
+
53
+ ### Key Features
54
+ - **State-of-the-Art Performance**: Achieves state-of-the-art performance in both instruction-following capability and generated motion quality.
55
+
56
+ - **Billion-Scale Models**: We are the first to successfully scale DiT-based models to the billion-parameter level for text-to-motion generation. This results in superior instruction understanding and following capabilities, outperforming comparable open-source models.
57
+
58
+ - **Advanced Three-Stage Training**: Our models are trained using a comprehensive three-stage process:
59
+
60
+ - *Large-Scale Pre-training*: Trained on over 3,000 hours of diverse motion data to learn a broad motion prior.
61
+
62
+ - *High-Quality Fine-tuning*: Fine-tuned on 400 hours of curated, high-quality 3D motion data to enhance motion detail and smoothness.
63
+
64
+ - *Reinforcement Learning*: Utilizes Reinforcement Learning from human feedback and reward models to further refine instruction-following and motion naturalness.
65
+
66
+
67
+
68
+ <p align="center">
69
+ <img src="./assets/pipeline.png" alt="System Overview" width="100%">
70
+ </p>
71
+
72
+ <p align="center">
73
+ <img src="./assets/arch.png" alt="Architecture" width="100%">
74
+ </p>
75
+
76
+ <p align="center">
77
+ <img src="./assets/sotacomp.png" alt="ComparisonSoTA" width="100%">
78
+ </p>
79
+
80
+
81
+ ## 🔗 BibTeX
82
+
83
+ If you found this repository helpful, please cite our reports:
84
+
85
+ ```bibtex
86
+ @article{hymotion2025,
87
+ title={HY-Motion 1.0: Scaling Flow Matching Models for Text-To-Motion Generation},
88
+ author={Tencent Hunyuan 3D Digital Human Team},
89
+ journal={arXiv preprint arXiv:2512.23464},
90
+ year={2025}
91
+ }
92
+ ```
93
+
94
+ ## Acknowledgements
95
+
96
+ We would like to thank the contributors to the [FLUX](https://github.com/black-forest-labs/flux), [diffusers](https://github.com/huggingface/diffusers), [HuggingFace](https://huggingface.co), [SMPL](https://smpl.is.tue.mpg.de/)/[SMPLH](https://mano.is.tue.mpg.de/), [CLIP](https://github.com/openai/CLIP), [Qwen3](https://github.com/QwenLM/Qwen3), [PyTorch3D](https://github.com/facebookresearch/pytorch3d), [kornia](https://github.com/kornia/kornia), [transforms3d](https://github.com/matthew-brett/transforms3d), [FBX-SDK](https://www.autodesk.com/developer-network/platform-technologies/fbx-sdk-2020-0), [GVHMR](https://zju3dv.github.io/gvhmr/), and [HunyuanVideo](https://github.com/Tencent-Hunyuan/HunyuanVideo) repositories or tools, for their open research and exploration.
assets/arch.png ADDED

Git LFS Details

  • SHA256: f68e50a49e3ce61d55056ff41e2f8857650321fc9c6962c541134c9e67b49067
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
assets/banner.png ADDED

Git LFS Details

  • SHA256: bd96308664c230af9fc7c4a37b91e58d3fff115f51da9f6dab261f695e330fe8
  • Pointer size: 130 Bytes
  • Size of remote file: 37.9 kB
assets/config_simplified.yml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ network_module: hymotion/network/hymotion_mmdit.HunyuanMotionMMDiT
2
+ network_module_args:
3
+ apply_rope_to_single_branch: false
4
+ ctxt_input_dim: 4096
5
+ dropout: 0.0
6
+ feat_dim: 1024
7
+ input_dim: 201
8
+ mask_mode: narrowband
9
+ mlp_ratio: 4.0
10
+ num_heads: 16
11
+ num_layers: 18
12
+ time_factor: 1000.0
13
+ vtxt_input_dim: 768
14
+ train_pipeline: hymotion/pipeline/motion_diffusion.MotionFlowMatching
15
+ train_pipeline_args:
16
+ enable_ctxt_null_feat: true
17
+ enable_special_game_feat: true
18
+ infer_noise_scheduler_cfg:
19
+ validation_steps: 50
20
+ losses_cfg:
21
+ recons:
22
+ name: SmoothL1Loss
23
+ weight: 1.0
24
+ noise_scheduler_cfg:
25
+ method: euler
26
+ output_mesh_fps: 30
27
+ random_generator_on_gpu: true
28
+ test_cfg:
29
+ mean_std_dir: ./stats/
30
+ text_guidance_scale: 5.0
31
+ text_encoder_cfg:
32
+ llm_type: qwen3
33
+ max_length_llm: 128
34
+ text_encoder_module: hymotion/network/text_encoders/text_encoder.HYTextModel
35
+ train_cfg:
36
+ cond_mask_prob: 0.1
37
+ train_frames: 360
assets/pipeline.png ADDED

Git LFS Details

  • SHA256: 7e05f7f119f999330a19bd66f3953bfce245e79303326c4855a0fabadf335aef
  • Pointer size: 131 Bytes
  • Size of remote file: 214 kB
assets/sotacomp.png ADDED

Git LFS Details

  • SHA256: 52b3ec6f6d6d19b67fc956c83bd6c4116109cd2153993ce2ca38274eb7993426
  • Pointer size: 131 Bytes
  • Size of remote file: 524 kB
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 9ca4c0fe8950edf378fb1af9179126ef465424d6aac6becdd08bdd506b6812b4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB
assets/wooden_models/boy_Rigging_smplx_tex.fbx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e1a4fc5b121d5fa61a631ee22ba360ca128279d794d1ed75b2acb9486e71cc8
3
+ size 16490768
examples/example_prompts/example_subset.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "test_prompts_subset": [
3
+ "A person jumps upward with both legs twice.#90#001",
4
+ "A person jumps on their right leg.#90#002",
5
+ "A person climbs upward, moving up the slope.#60#003",
6
+ "A person climbs an obstacle.#60#004",
7
+ "A person walks forward.#120#005",
8
+ "A person walks forward, moving arms and legs while looking left and right.#180#006",
9
+ "A person walks unsteadily, then slowly sits down.#150#007",
10
+ "A person turns backward 180 degrees, then walks forward.#120#008",
11
+ "A person walks in a catwalk style, swinging their left arm while placing their right hand on their hip.#180#009",
12
+ "A person squats down on tiptoe#120#010",
13
+ "A person sits down on a chair.#90#011",
14
+ "A person runs forward.#60#012",
15
+ "A person jumps up.#90#013",
16
+ "A person jumps forward lightly, taking two steps.#69#014",
17
+ "A person shoots a basketball.#60#015",
18
+ "A person finishes freestyle swimming, then surfaces.#120#016",
19
+ "A person swings a golf club, hitting the ball forward.#111#017",
20
+ "A person runs forward, then kicks a soccer ball.#60#018",
21
+ "A person walks on a tightrope.#180#019",
22
+ "A person performs a yoga camel pose, extending their back and lifting their chest.#210#020",
23
+ "A person performs a sit-up, holding their head with both hands.#150#021",
24
+ "A person performs a lunge stretch, hands on hips.#150#022",
25
+ "A person performs a deadlift, lifting a barbell from the ground.#150#023",
26
+ "A person marches in place, swinging their arms forward and backward.#210#024",
27
+ "A person perform a squat, not standing up#93#025",
28
+ "A person performs a squat#93#026",
29
+ "A person performs a front arm raise, then does a squat.#93#027",
30
+ "A person performs a squat, raising both arms forward.#240#028",
31
+ "A person does a squat, balling both hands into fists, lowering into a squat, then standing up.#195#029",
32
+ "A person plays the piano.#270#030",
33
+ "A person dances bachata, executing rhythmic hip movements and footwork.#240#031",
34
+ "A person plays the drums while sitting down, with wide, crossing arm movements.#90#032",
35
+ "A person plays the drums while sitting down, with arms spreading wide and then crossing over.#90#033",
36
+ "A person dances jazz, jumping rhythmically.#240#034",
37
+ "A person practices tai chi, performing slow, controlled movements.#270#035",
38
+ "A person waves their right hand, sitting on a beach chair.#71#036",
39
+ "A person was sweeping the floor with their head down.#180#037",
40
+ "A person picks up an object from ground#117#038",
41
+ "A person picks up an object from lower ground with two hands#99#039",
42
+ "A person picks up an object from lower ground with two hands, and lifts over head#126#040",
43
+ "A person speaks, gesturing with both hands.#75#041",
44
+ "A person lies on a bed, reading a book.#180#042",
45
+ "A person bends down to pick up an object, then stands up straight.#150#043",
46
+ "A person flips the wok#61#044",
47
+ "A person rolls over while lying down.#60#045",
48
+ "A person walks forward, holding a tray at shoulder height with one hand.#93#046",
49
+ "A person stands up from the chair, then stretches the arms.#300#047",
50
+ "A person turns to evade.#61#048",
51
+ "A person collapses to the ground after being hit.#60#049",
52
+ "A person swings a sword forward.#60#050",
53
+ "A person attacks, holding a shield in the right hand and a sword in the left.#45#051",
54
+ "A person walks like a zombie, dragging their feet forward.#120#052",
55
+ "A person performs a taekwondo kick, extending their leg forcefully.#60#053",
56
+ "A person blocks with a shield.#60#054",
57
+ "A person lifts a long gun, then walks forward slowly.#90#055",
58
+ "A person stumbles, being hit.#45#056",
59
+ "A person assumes a boxing stance, then shifts weight to the right and punches with the right hand.#60#057"
60
+ ]
61
+ }
gradio_app.py ADDED
@@ -0,0 +1,894 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import codecs as cs
3
+ import json
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import re
8
+ import textwrap
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import gradio as gr
12
+ import torch
13
+ from huggingface_hub import snapshot_download
14
+
15
def try_to_download_model():
    """Download the HY-Motion-1.0-Lite weights from the Hugging Face Hub.

    Only files under the ``HY-Motion-1.0-Lite`` folder of the
    ``tencent/HY-Motion-1.0`` repo are fetched, into ``./downloaded_models``.

    Returns:
        str: Local filesystem path of the downloaded model folder.
    """
    repo_id = "tencent/HY-Motion-1.0"
    target_folder = "HY-Motion-1.0-Lite"
    # Fix: the original used print(f">>> start download ", repo_id, target_folder),
    # an f-string with no placeholders plus positional args, yielding a malformed
    # log line with a stray double space.
    print(f">>> start download {repo_id}/{target_folder}")
    local_dir = snapshot_download(
        repo_id=repo_id,
        # Restrict the snapshot to the Lite model folder only.
        allow_patterns=f"{target_folder}/*",
        local_dir="./downloaded_models",
    )
    final_model_path = os.path.join(local_dir, target_folder)
    print(f">>> Final model path: {final_model_path}")
    return final_model_path
27
+
28
+
29
# Import spaces for Hugging Face Zero GPU support
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
    # Create a dummy decorator when spaces is not available.
    # The shim mirrors the real API's two usage forms:
    #   @spaces.GPU                  -> func is the callable, returned as-is
    #   @spaces.GPU(duration=120)    -> func is None, return a pass-through decorator
    # In both cases the wrapped function is left unchanged (no GPU allocation).
    class spaces:
        @staticmethod
        def GPU(func=None, duration=None):
            def decorator(fn):
                return fn
            if func is not None:
                return func
            return decorator
44
+
45
+ from hymotion.utils.t2m_runtime import T2MRuntime
46
+
47
# One worker per visible CUDA device; a single worker when running CPU-only.
NUM_WORKERS = torch.cuda.device_count() if torch.cuda.is_available() else 1

# Global runtime instance for Zero GPU lazy loading.
# Populated on first use by _init_runtime_if_needed(); _global_args must be
# set beforehand or that call raises RuntimeError.
_global_runtime = None
_global_args = None
52
+
53
+
54
def _init_runtime_if_needed():
    """Initialize runtime lazily for Zero GPU support.

    Returns the module-level T2MRuntime singleton, constructing it on first
    call from `_global_args`.

    Raises:
        RuntimeError: If `_global_args` has not been set yet.
    """
    global _global_runtime, _global_args
    # Already constructed: reuse the singleton.
    if _global_runtime is not None:
        return _global_runtime

    if _global_args is None:
        raise RuntimeError("Runtime args not set. Call set_runtime_args() first.")

    args = _global_args
    # Config and checkpoint are expected side by side under model_path.
    cfg = osp.join(args.model_path, "config.yml")
    ckpt = osp.join(args.model_path, "latest.ckpt")

    # Missing checkpoint degrades gracefully: the runtime still starts, but
    # motion generation is unavailable.
    skip_model_loading = False
    if not os.path.exists(ckpt):
        print(f">>> [WARNING] Checkpoint file not found: {ckpt}")
        print(f">>> [WARNING] Model loading will be skipped. Motion generation will not be available.")
        skip_model_loading = True

    print(">>> Initializing T2MRuntime...")
    # Default to HF-hosted models unless the caller overrode the env var.
    if "USE_HF_MODELS" not in os.environ:
        os.environ["USE_HF_MODELS"] = "1"

    skip_text = False
    _global_runtime = T2MRuntime(
        config_path=cfg,
        ckpt_name=ckpt,
        skip_text=skip_text,
        device_ids=None,
        prompt_engineering_host=args.prompt_engineering_host,
        skip_model_loading=skip_model_loading,
    )
    return _global_runtime
87
+
88
+
89
@spaces.GPU(duration=120)
def generate_motion_on_gpu(
    text: str,
    seeds_csv: str,
    motion_duration: float,
    cfg_scale: float,
    output_format: str,
    original_text: str,
    output_dir: str,
) -> Tuple[str, List[str]]:
    """
    GPU-decorated function for motion generation.
    This function will request GPU allocation on Hugging Face Zero GPU.

    Args:
        text: (Possibly rewritten) prompt used for generation.
        seeds_csv: Comma-separated random seeds, one sample per seed.
        motion_duration: Target motion length in seconds.
        cfg_scale: Classifier-free guidance scale.
        output_format: Export format passed through to the runtime.
        original_text: The user's original prompt, before rewriting.
        output_dir: Directory where generated files are written.

    Returns:
        Tuple of (viewer HTML content, list of exported FBX file paths).
    """
    # Runtime is created lazily so the GPU is only requested when needed.
    runtime = _init_runtime_if_needed()

    # Third return value of generate_motion is unused here.
    html_content, fbx_files, _ = runtime.generate_motion(
        text=text,
        seeds_csv=seeds_csv,
        duration=motion_duration,
        cfg_scale=cfg_scale,
        output_format=output_format,
        original_text=original_text,
        output_dir=output_dir,
    )
    return html_content, fbx_files
115
+
116
+
117
# define data sources: maps a source name (shown in the UI) to the prompt
# file it is loaded from by load_examples_from_txt().
DATA_SOURCES = {
    "example_prompts": "examples/example_prompts/example_subset.json",
}
121
+
122
+ # create interface
123
+ APP_CSS = """
124
+ :root{
125
+ --primary-start:#667eea; --primary-end:#764ba2;
126
+ --secondary-start:#4facfe; --secondary-end:#00f2fe;
127
+ --accent-start:#f093fb; --accent-end:#f5576c;
128
+ --page-bg:linear-gradient(135deg,#f5f7fa 0%,#c3cfe2 100%);
129
+ --card-bg:linear-gradient(135deg,#ffffff 0%,#f8f9fa 100%);
130
+ --radius:12px;
131
+ --iframe-bg:#ffffff;
132
+ }
133
+
134
+ /* Dark mode variables */
135
+ [data-theme="dark"], .dark {
136
+ --page-bg:linear-gradient(135deg,#1a1a1a 0%,#2d3748 100%);
137
+ --card-bg:linear-gradient(135deg,#2d3748 0%,#374151 100%);
138
+ --text-primary:#f7fafc;
139
+ --text-secondary:#e2e8f0;
140
+ --border-color:#4a5568;
141
+ --input-bg:#374151;
142
+ --input-border:#4a5568;
143
+ --iframe-bg:#1a1a2e;
144
+ }
145
+
146
+ /* Page and card */
147
+ .gradio-container{
148
+ background:var(--page-bg) !important;
149
+ min-height:100vh !important;
150
+ color:var(--text-primary, #333) !important;
151
+ }
152
+
153
+ .main-header{
154
+ background:transparent !important; border:none !important; box-shadow:none !important;
155
+ padding:0 !important; margin:10px 0 16px !important;
156
+ text-align:center !important;
157
+ }
158
+
159
+ .main-header h1, .main-header p, .main-header li {
160
+ color:var(--text-primary, #333) !important;
161
+ }
162
+
163
+ .left-panel,.right-panel{
164
+ background:var(--card-bg) !important;
165
+ border:1px solid var(--border-color, #e9ecef) !important;
166
+ border-radius:15px !important;
167
+ box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
168
+ padding:24px !important;
169
+ }
170
+
171
+ .gradio-accordion{
172
+ border:1px solid var(--border-color, #e1e5e9) !important;
173
+ border-radius:var(--radius) !important;
174
+ margin:12px 0 !important; background:transparent !important;
175
+ }
176
+
177
+ .gradio-accordion summary{
178
+ background:transparent !important;
179
+ padding:14px 18px !important;
180
+ font-weight:600 !important;
181
+ color:var(--text-primary, #495057) !important;
182
+ }
183
+
184
+ .gradio-group{
185
+ background:transparent !important; border:none !important;
186
+ border-radius:8px !important; padding:12px 0 !important; margin:8px 0 !important;
187
+ }
188
+
189
+ /* Input class style - dark mode adaptation */
190
+ .gradio-textbox input,.gradio-textbox textarea,.gradio-dropdown .wrap{
191
+ border-radius:8px !important;
192
+ border:2px solid var(--input-border, #e9ecef) !important;
193
+ background:var(--input-bg, #fff) !important;
194
+ color:var(--text-primary, #333) !important;
195
+ transition:.2s all !important;
196
+ }
197
+
198
+ .gradio-textbox input:focus,.gradio-textbox textarea:focus,.gradio-dropdown .wrap:focus-within{
199
+ border-color:var(--primary-start) !important;
200
+ box-shadow:0 0 0 3px rgba(102,126,234,.1) !important;
201
+ }
202
+
203
+ .gradio-slider input[type="range"]{
204
+ background:linear-gradient(to right,var(--primary-start),var(--primary-end)) !important;
205
+ border-radius:10px !important;
206
+ }
207
+
208
+ .gradio-checkbox input[type="checkbox"]{
209
+ border-radius:4px !important;
210
+ border:2px solid var(--input-border, #e9ecef) !important;
211
+ transition:.2s all !important;
212
+ }
213
+
214
+ .gradio-checkbox input[type="checkbox"]:checked{
215
+ background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important;
216
+ border-color:var(--primary-start) !important;
217
+ }
218
+
219
+ /* Label text color adaptation */
220
+ .gradio-textbox label, .gradio-dropdown label, .gradio-slider label,
221
+ .gradio-checkbox label, .gradio-html label {
222
+ color:var(--text-primary, #333) !important;
223
+ }
224
+
225
+ .gradio-textbox .info, .gradio-dropdown .info, .gradio-slider .info,
226
+ .gradio-checkbox .info {
227
+ color:var(--text-secondary, #666) !important;
228
+ }
229
+
230
+ /* Status information - dark mode adaptation */
231
+ .gradio-textbox[data-testid*="状态信息"] input{
232
+ background:var(--input-bg, linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%)) !important;
233
+ border:2px solid var(--input-border, #dee2e6) !important;
234
+ color:var(--text-primary, #495057) !important;
235
+ font-weight:500 !important;
236
+ }
237
+
238
+ /* Button base class and variant */
239
+ .generate-button,.rewrite-button,.dice-button{
240
+ border:none !important; color:#fff !important; font-weight:600 !important;
241
+ border-radius:8px !important; transition:.3s all !important;
242
+ box-shadow:0 4px 15px rgba(0,0,0,.12) !important;
243
+ }
244
+
245
+ .generate-button{ background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important; }
246
+ .rewrite-button{ background:linear-gradient(45deg,var(--secondary-start),var(--secondary-end)) !important; }
247
+ .dice-button{
248
+ background:linear-gradient(45deg,var(--accent-start),var(--accent-end)) !important;
249
+ height:40px !important;
250
+ }
251
+
252
+ .generate-button:hover,.rewrite-button:hover{ transform:translateY(-2px) !important; }
253
+ .dice-button:hover{
254
+ transform:scale(1.05) !important;
255
+ box-shadow:0 4px 12px rgba(240,147,251,.28) !important;
256
+ }
257
+
258
+ .dice-container{
259
+ display:flex !important;
260
+ align-items:flex-end !important;
261
+ justify-content:center !important;
262
+ }
263
+
264
+ /* Right panel clipping overflow, avoid double scrollbars */
265
+ .right-panel{
266
+ background:var(--card-bg) !important;
267
+ border:1px solid var(--border-color, #e9ecef) !important;
268
+ border-radius:15px !important;
269
+ box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
270
+ padding:24px !important; overflow:hidden !important;
271
+ }
272
+
273
+ /* Main content row - ensure equal heights */
274
+ .main-row {
275
+ display: flex !important;
276
+ align-items: stretch !important;
277
+ }
278
+
279
+ /* Flask area - match left panel height */
280
+ .flask-display{
281
+ padding:0 !important; margin:0 !important; border:none !important;
282
+ box-shadow:none !important; background:var(--iframe-bg) !important;
283
+ border-radius:10px !important; position:relative !important;
284
+ height:100% !important; min-height:750px !important;
285
+ display:flex !important; flex-direction:column !important;
286
+ }
287
+
288
+ .flask-display iframe{
289
+ width:100% !important; flex:1 !important; min-height:750px !important;
290
+ border:none !important; border-radius:10px !important; display:block !important;
291
+ background:var(--iframe-bg) !important;
292
+ }
293
+
294
+ /* Right panel should stretch to match left panel */
295
+ .right-panel{
296
+ background:var(--card-bg) !important;
297
+ border:1px solid var(--border-color, #e9ecef) !important;
298
+ border-radius:15px !important;
299
+ box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
300
+ padding:24px !important; overflow:hidden !important;
301
+ display:flex !important; flex-direction:column !important;
302
+ }
303
+
304
+ /* Ensure dropdown menu is visible in dark mode */
305
+ [data-theme="dark"] .gradio-dropdown .wrap,
306
+ .dark .gradio-dropdown .wrap {
307
+ background:var(--input-bg) !important;
308
+ color:var(--text-primary) !important;
309
+ }
310
+
311
+ [data-theme="dark"] .gradio-dropdown .option,
312
+ .dark .gradio-dropdown .option {
313
+ background:var(--input-bg) !important;
314
+ color:var(--text-primary) !important;
315
+ }
316
+
317
+ [data-theme="dark"] .gradio-dropdown .option:hover,
318
+ .dark .gradio-dropdown .option:hover {
319
+ background:var(--border-color) !important;
320
+ }
321
+
322
+ .footer{
323
+ text-align:center !important;
324
+ margin-top:20px !important;
325
+ padding:10px !important;
326
+ color:var(--text-secondary, #666) !important;
327
+ }
328
+ """
329
+
330
+ HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground"
331
+
332
+ FOOTER_MD = "*This is a Beta version, any issues or feedback are welcome!*"
333
+
334
+ HTML_OUTPUT_PLACEHOLDER = """
335
+ <div style='height: 750px; width: 100%; border-radius: 8px; border-color: #e5e7eb; border-style: solid; border-width: 1px; display: flex; justify-content: center; align-items: center;'>
336
+ <div style='text-align: center; font-size: 16px; color: #6b7280;'>
337
+ <p style="color: #8d8d8d;">Welcome to HY-Motion-1.0!</p>
338
+ <p style="color: #8d8d8d;">No motion visualization here yet.</p>
339
+ </div>
340
+ </div>
341
+ """
342
+
343
+
344
def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12):
    """Load (prompt, duration) example pairs from a ``.txt`` or ``.json`` file.

    Each example line has the form ``<text>#<frames>#<id>``; the frame count
    is converted to seconds via *example_record_fps* and clamped to
    *max_duration*. Lines without a ``#`` separator get a 5-second default.
    Blank lines and lines starting with ``#`` are ignored. Returns ``[]`` when
    the file is missing or unreadable.
    """

    def _parse_line(line: str) -> Optional[Tuple[str, float]]:
        stripped = line.strip()
        # Skip blanks and comment lines.
        if not stripped or stripped.startswith("#"):
            return None
        pieces = stripped.split("#")
        if len(pieces) >= 2:
            prompt = pieces[0].strip()
            seconds = min(int(pieces[1]) / example_record_fps, max_duration)
        else:
            prompt = stripped
            seconds = 5.0
        return prompt, seconds

    collected: List[Tuple[str, float]] = []
    if not os.path.exists(txt_path):
        print(f">>> Examples file not found: {txt_path}")
        return collected

    try:
        if txt_path.endswith(".txt"):
            # Plain text: one example per line.
            with cs.open(txt_path, "r", encoding="utf-8") as f:
                for raw in f.readlines():
                    parsed = _parse_line(raw)
                    if parsed is not None:
                        collected.append(parsed)
        elif txt_path.endswith(".json"):
            # JSON: {group_name: [example_line, ...], ...}
            with cs.open(txt_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            for key, prompt_list in data.items():
                # Skip raw-Chinese duplicates and format-template entries.
                if "_raw_chn" in key or "GENERATE_PROMPT_FORMAT" in key:
                    continue
                for raw in prompt_list:
                    parsed = _parse_line(raw)
                    if parsed is not None:
                        collected.append(parsed)
        print(f">>> Loaded {len(collected)} examples from {txt_path}")
    except Exception as e:
        # Best-effort loader: report and fall through with whatever was parsed.
        print(f">>> Failed to load examples from {txt_path}: {e}")

    return collected
392
+
393
+
394
+ class T2MGradioUI:
395
    def __init__(self, runtime: T2MRuntime, args: argparse.Namespace):
        """Build the Gradio UI wrapper around a text-to-motion runtime.

        Args:
            runtime: T2MRuntime used for prompt rewriting and motion generation.
            args: Parsed CLI arguments; `prompt_engineering_host` and
                `disable_rewrite` are consulted here.
        """
        self.runtime = runtime
        self.args = args

        # Check if rewrite is available:
        # - prompt_engineering_host must be provided
        # - disable_rewrite must not be set
        print(f">>> args: {vars(args)}")
        self.rewrite_available = (
            args.prompt_engineering_host is not None
            and args.prompt_engineering_host.strip() != ""
            and not args.disable_rewrite
        )

        # Example prompts keyed by data-source name; filled by _init_example_data().
        self.all_example_data = {}
        self._init_example_data()
411
+
412
    def _init_example_data(self):
        """Populate self.all_example_data from DATA_SOURCES.

        Sources that fail to load (missing or empty files) fall back to a
        small built-in list so the UI always has example prompts.
        """
        for source_name, file_path in DATA_SOURCES.items():
            examples = load_examples_from_txt(file_path)
            if examples:
                self.all_example_data[source_name] = examples
            else:
                # provide default examples as fallback
                self.all_example_data[source_name] = [
                    ("Twist at the waist and punch across the body.", 3.0),
                    ("A person is running then takes big leap.", 3.0),
                    ("A person holds a railing and walks down a set of stairs.", 5.0),
                    (
                        "A man performs a fluid and rhythmic hip-hop style dance, incorporating body waves, arm gestures, and side steps.",
                        5.0,
                    ),
                ]
        print(f">>> Loaded data sources: {list(self.all_example_data.keys())}")
429
+
430
+ def _get_header_text(self):
431
+ return HEADER_BASE_MD
432
+
433
+ def _generate_random_seeds(self):
434
+ seeds = [random.randint(0, 999) for _ in range(4)]
435
+ return ",".join(map(str, seeds))
436
+
437
+ def _prompt_engineering(
438
+ self, text: str, duration: float, enable_rewrite: bool = True, enable_duration_est: bool = True
439
+ ):
440
+ if not text.strip():
441
+ return "", gr.update(interactive=False), gr.update()
442
+
443
+ call_llm = enable_rewrite or enable_duration_est
444
+ if not call_llm:
445
+ print(f"\t>>> Using original duration and original text...")
446
+ predicted_duration = duration
447
+ rewritten_text = text
448
+ else:
449
+ print(f"\t>>> Using LLM to estimate duration/rewrite text...")
450
+ try:
451
+ predicted_duration, rewritten_text = self.runtime.rewrite_text_and_infer_time(text=text)
452
+ except Exception as e:
453
+ print(f"\t>>> Text rewriting/duration prediction failed: {e}")
454
+ return (
455
+ f"❌ Text rewriting/duration prediction failed: {str(e)}",
456
+ gr.update(interactive=False),
457
+ gr.update(),
458
+ )
459
+ if not enable_rewrite:
460
+ rewritten_text = text
461
+ if not enable_duration_est:
462
+ predicted_duration = duration
463
+
464
+ return rewritten_text, gr.update(interactive=True), gr.update(value=predicted_duration)
465
+
466
+ def _generate_motion(
467
+ self,
468
+ original_text: str,
469
+ rewritten_text: str,
470
+ seed_input: str,
471
+ duration: float,
472
+ cfg_scale: float,
473
+ ) -> Tuple[str, List[str]]:
474
+ # When rewrite is not available, use original_text directly
475
+ if not self.rewrite_available:
476
+ text_to_use = original_text.strip()
477
+ if not text_to_use:
478
+ return "Error: Input text is empty, please enter text first", []
479
+ else:
480
+ text_to_use = rewritten_text.strip()
481
+ if not text_to_use:
482
+ return "Error: Rewritten text is empty, please rewrite the text first", []
483
+
484
+ try:
485
+ # Use runtime from global if available (for Zero GPU), otherwise use self.runtime
486
+ runtime = _global_runtime if _global_runtime is not None else self.runtime
487
+ fbx_ok = getattr(runtime, "fbx_available", False)
488
+ req_format = "fbx" if fbx_ok else "dict"
489
+
490
+ # Use GPU-decorated function for Zero GPU support
491
+ html_content, fbx_files = generate_motion_on_gpu(
492
+ text=text_to_use,
493
+ seeds_csv=seed_input,
494
+ motion_duration=duration,
495
+ cfg_scale=cfg_scale,
496
+ output_format=req_format,
497
+ original_text=original_text,
498
+ output_dir=self.args.output_dir,
499
+ )
500
+ # Escape HTML content for srcdoc attribute
501
+ escaped_html = html_content.replace('"', '&quot;')
502
+ # Return iframe with srcdoc - directly embed HTML content
503
+ iframe_html = f'''
504
+ <iframe
505
+ srcdoc="{escaped_html}"
506
+ width="100%"
507
+ height="750px"
508
+ style="border: none; border-radius: 12px; box-shadow: 0 4px 20px rgba(0,0,0,0.1);"
509
+ ></iframe>
510
+ '''
511
+ return iframe_html, fbx_files
512
+ except Exception as e:
513
+ print(f"\t>>> Motion generation failed: {e}")
514
+ return (
515
+ f"❌ Motion generation failed: {str(e)}\n\nPlease check the input parameters or try again later",
516
+ [],
517
+ )
518
+
519
+ def _get_example_choices(self):
520
+ """Get all example choices from all data sources"""
521
+ choices = ["Custom Input"]
522
+ for source_name in self.all_example_data:
523
+ example_data = self.all_example_data[source_name]
524
+ for text, _ in example_data:
525
+ display_text = f"{text[:50]}..." if len(text) > 50 else text
526
+ choices.append(display_text)
527
+ return choices
528
+
529
+ def _on_example_select(self, selected_example):
530
+ """When selecting an example, the callback function"""
531
+ if selected_example == "Custom Input":
532
+ return "", self._generate_random_seeds(), gr.update()
533
+ else:
534
+ # find the corresponding example from all data sources
535
+ for source_name in self.all_example_data:
536
+ example_data = self.all_example_data[source_name]
537
+ for text, duration in example_data:
538
+ display_text = f"{text[:50]}..." if len(text) > 50 else text
539
+ if display_text == selected_example:
540
+ return text, self._generate_random_seeds(), gr.update(value=duration)
541
+ return "", self._generate_random_seeds(), gr.update()
542
+
543
+ def build_ui(self):
544
+ with gr.Blocks(css=APP_CSS) as demo:
545
+ self.header_md = gr.Markdown(HEADER_BASE_MD, elem_classes=["main-header"])
546
+
547
+ with gr.Row():
548
+ # Left control panel
549
+ with gr.Column(scale=2, elem_classes=["left-panel"]):
550
+ # Input textbox
551
+ self.text_input = gr.Textbox(
552
+ label="📝 Input Text",
553
+ placeholder="Enter text to generate motion, support Chinese and English text input.",
554
+ )
555
+ # Rewritten textbox
556
+ self.rewritten_text = gr.Textbox(
557
+ label="✏️ Rewritten Text",
558
+ placeholder="Rewritten text will be displayed here, you can further edit",
559
+ interactive=True,
560
+ visible=False,
561
+ )
562
+ # Duration slider
563
+ self.duration_slider = gr.Slider(
564
+ minimum=0.5,
565
+ maximum=12,
566
+ value=5.0,
567
+ step=0.1,
568
+ label="⏱️ Action Duration (seconds)",
569
+ info="Feel free to adjust the action duration",
570
+ )
571
+
572
+ # Execute buttons
573
+ with gr.Row():
574
+ if self.rewrite_available:
575
+ self.rewrite_btn = gr.Button(
576
+ "🔄 Rewrite Text",
577
+ variant="secondary",
578
+ size="lg",
579
+ elem_classes=["rewrite-button"],
580
+ )
581
+ else:
582
+ # Create a hidden/disabled placeholder button
583
+ self.rewrite_btn = gr.Button(
584
+ "🔄 Rewrite Text (Unavailable)",
585
+ variant="secondary",
586
+ size="lg",
587
+ elem_classes=["rewrite-button"],
588
+ interactive=False,
589
+ visible=False,
590
+ )
591
+
592
+ self.generate_btn = gr.Button(
593
+ "🚀 Generate Motion",
594
+ variant="primary",
595
+ size="lg",
596
+ elem_classes=["generate-button"],
597
+ interactive=not self.rewrite_available, # Enable directly if rewrite not available
598
+ )
599
+
600
+ if not self.rewrite_available:
601
+ gr.Markdown(
602
+ "> ⚠️ **Prompt engineering is not available.** Text rewriting and duration estimation are disabled. Your input text and duration will be used directly."
603
+ )
604
+
605
+ # Advanced settings
606
+ with gr.Accordion("🔧 Advanced Settings", open=False):
607
+ self._build_advanced_settings()
608
+
609
+ # Example selection dropdown
610
+ self.example_dropdown = gr.Dropdown(
611
+ choices=self._get_example_choices(),
612
+ value="Custom Input",
613
+ label="📚 Test Examples",
614
+ info="Select a preset example or input your own text above",
615
+ interactive=True,
616
+ )
617
+
618
+ # Status message depends on whether rewrite is available
619
+ if self.rewrite_available:
620
+ status_msg = "Please click the [🔄 Rewrite Text] button to rewrite the text first"
621
+ else:
622
+ status_msg = "Enter your text and click [🚀 Generate Motion] directly."
623
+
624
+ self.status_output = gr.Textbox(
625
+ label="📊 Status Information",
626
+ value=status_msg,
627
+ )
628
+
629
+ # FBX Download section
630
+ with gr.Row(visible=False) as self.fbx_download_row:
631
+ if getattr(self.runtime, "fbx_available", False):
632
+ self.fbx_files = gr.File(
633
+ label="📦 Download FBX Files",
634
+ file_count="multiple",
635
+ interactive=False,
636
+ )
637
+ else:
638
+ self.fbx_files = gr.State([])
639
+
640
+ # Right display area
641
+ with gr.Column(scale=3):
642
+ self.output_display = gr.HTML(
643
+ value=HTML_OUTPUT_PLACEHOLDER,
644
+ show_label=False,
645
+ elem_classes=["flask-display"]
646
+ )
647
+
648
+ # Footer
649
+ gr.Markdown(FOOTER_MD, elem_classes=["footer"])
650
+
651
+ self._bind_events()
652
+ demo.load(fn=self._get_header_text, outputs=[self.header_md])
653
+ return demo
654
+
655
+ def _build_advanced_settings(self):
656
+ # Only show rewrite options if rewrite is available
657
+ if self.rewrite_available:
658
+ with gr.Group():
659
+ gr.Markdown("### 🔄 Text Rewriting Options")
660
+ with gr.Row():
661
+ self.enable_rewrite = gr.Checkbox(
662
+ label="Enable Text Rewriting",
663
+ value=True,
664
+ info="Automatically optimize text prompt to get better motion generation",
665
+ )
666
+
667
+ with gr.Group():
668
+ gr.Markdown("### ⏱️ Duration Settings")
669
+ self.enable_duration_est = gr.Checkbox(
670
+ label="Enable Duration Estimation",
671
+ value=True,
672
+ info="Automatically estimate the duration of the motion",
673
+ )
674
+ else:
675
+ # Create hidden placeholders with default values (disabled)
676
+ self.enable_rewrite = gr.Checkbox(
677
+ label="Enable Text Rewriting",
678
+ value=False,
679
+ visible=False,
680
+ )
681
+ self.enable_duration_est = gr.Checkbox(
682
+ label="Enable Duration Estimation",
683
+ value=False,
684
+ visible=False,
685
+ )
686
+ with gr.Group():
687
+ gr.Markdown("### ⚠️ Prompt Engineering Unavailable")
688
+ gr.Markdown(
689
+ "Text rewriting and duration estimation are not available. "
690
+ "Your input text and duration will be used directly."
691
+ )
692
+
693
+ with gr.Group():
694
+ gr.Markdown("### ⚙️ Generation Parameters")
695
+ with gr.Row():
696
+ with gr.Column(scale=3):
697
+ self.seed_input = gr.Textbox(
698
+ label="🎯 Random Seed List (comma separated)",
699
+ value="0,1,2,3",
700
+ placeholder="Enter comma separated seed list (e.g.: 0,1,2,3)",
701
+ info="Random seeds control the diversity of generated motions",
702
+ )
703
+ with gr.Column(scale=1, min_width=60, elem_classes=["dice-container"]):
704
+ self.dice_btn = gr.Button(
705
+ "🎲 Lucky Button",
706
+ variant="secondary",
707
+ size="sm",
708
+ elem_classes=["dice-button"],
709
+ )
710
+
711
+ self.cfg_slider = gr.Slider(
712
+ minimum=1,
713
+ maximum=10,
714
+ value=5.0,
715
+ step=0.1,
716
+ label="⚙️ CFG Strength",
717
+ info="Text fidelity: higher = more faithful to the prompt",
718
+ )
719
+
720
+ def _bind_events(self):
721
+ # Generate random seeds
722
+ self.dice_btn.click(self._generate_random_seeds, outputs=[self.seed_input])
723
+
724
+ # Bind example selection event
725
+ self.example_dropdown.change(
726
+ fn=self._on_example_select,
727
+ inputs=[self.example_dropdown],
728
+ outputs=[self.text_input, self.seed_input, self.duration_slider],
729
+ )
730
+
731
+ # Rewrite text logic (only bind when rewrite is available)
732
+ if self.rewrite_available:
733
+ self.rewrite_btn.click(fn=lambda: "Rewriting text, please wait...", outputs=[self.status_output]).then(
734
+ self._prompt_engineering,
735
+ inputs=[
736
+ self.text_input,
737
+ self.duration_slider,
738
+ self.enable_rewrite,
739
+ self.enable_duration_est,
740
+ ],
741
+ outputs=[self.rewritten_text, self.generate_btn, self.duration_slider],
742
+ ).then(
743
+ fn=lambda: (
744
+ gr.update(visible=True),
745
+ "Text rewriting completed! Please check and edit the rewritten text, then click [🚀 Generate Motion]",
746
+ ),
747
+ outputs=[self.rewritten_text, self.status_output],
748
+ )
749
+
750
+ # Generate motion logic
751
+ self.generate_btn.click(
752
+ fn=lambda: "Generating motion, please wait... (It takes some extra time to start the renderer for the first generation)",
753
+ outputs=[self.status_output],
754
+ ).then(
755
+ self._generate_motion,
756
+ inputs=[
757
+ self.text_input,
758
+ self.rewritten_text,
759
+ self.seed_input,
760
+ self.duration_slider,
761
+ self.cfg_slider,
762
+ ],
763
+ outputs=[self.output_display, self.fbx_files],
764
+ concurrency_limit=NUM_WORKERS,
765
+ ).then(
766
+ fn=lambda fbx_list: (
767
+ (
768
+ "🎉 Motion generation completed! You can view the motion visualization result on the right. FBX files are ready for download."
769
+ if fbx_list
770
+ else "🎉 Motion generation completed! You can view the motion visualization result on the right"
771
+ ),
772
+ gr.update(visible=bool(fbx_list)),
773
+ ),
774
+ inputs=[self.fbx_files],
775
+ outputs=[self.status_output, self.fbx_download_row],
776
+ )
777
+
778
+ # Reset logic - different behavior based on rewrite availability
779
+ if self.rewrite_available:
780
+ self.text_input.change(
781
+ fn=lambda: (
782
+ gr.update(visible=False),
783
+ gr.update(interactive=False),
784
+ "Please click the [🔄 Rewrite Text] button to rewrite the text first",
785
+ ),
786
+ outputs=[self.rewritten_text, self.generate_btn, self.status_output],
787
+ )
788
+ else:
789
+ # When rewrite is not available, enable generate button directly when text is entered
790
+ self.text_input.change(
791
+ fn=lambda text: (
792
+ gr.update(visible=False),
793
+ gr.update(interactive=bool(text.strip())),
794
+ (
795
+ "Ready to generate! Click [🚀 Generate Motion] to start."
796
+ if text.strip()
797
+ else "Enter your text and click [🚀 Generate Motion] directly."
798
+ ),
799
+ ),
800
+ inputs=[self.text_input],
801
+ outputs=[self.rewritten_text, self.generate_btn, self.status_output],
802
+ )
803
+ # Only bind rewritten_text change when rewrite is available
804
+ if self.rewrite_available:
805
+ self.rewritten_text.change(
806
+ fn=lambda text: (
807
+ gr.update(interactive=bool(text.strip())),
808
+ (
809
+ "Rewritten text has been modified, you can click [🚀 Generate Motion]"
810
+ if text.strip()
811
+ else "Rewritten text cannot be empty, please enter valid text"
812
+ ),
813
+ ),
814
+ inputs=[self.rewritten_text],
815
+ outputs=[self.generate_btn, self.status_output],
816
+ )
817
+
818
+
819
def create_demo(final_model_path):
    """Create the Gradio demo with Zero GPU support.

    Args:
        final_model_path: Directory expected to contain `config.yml` and
            `latest.ckpt`.

    Returns:
        The built gr.Blocks demo.

    Raises:
        FileNotFoundError: If `config.yml` is missing from the model path.

    Side effects: sets the module globals `_global_runtime`/`_global_args`,
    creates the output directory, and may set the USE_HF_MODELS env var.
    """
    global _global_runtime, _global_args

    # Lightweight stand-in for argparse.Namespace used by T2MGradioUI.
    class Args:
        model_path = final_model_path
        output_dir = "output/gradio"
        prompt_engineering_host = os.environ.get("PROMPT_HOST", None)
        disable_rewrite = False

    args = Args()
    _global_args = args  # Set global args for lazy loading

    # Check required files:
    cfg = osp.join(args.model_path, "config.yml")
    ckpt = osp.join(args.model_path, "latest.ckpt")
    if not osp.exists(cfg):
        raise FileNotFoundError(f">>> Configuration file not found: {cfg}")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # For Zero GPU: Don't load model at startup, use lazy loading
    # Create a minimal runtime for UI initialization (without model loading)
    if SPACES_AVAILABLE:
        print(">>> Hugging Face Spaces detected. Using Zero GPU lazy loading.")
        print(">>> Model will be loaded on first GPU request.")

        # Create a placeholder runtime with minimal initialization for UI
        class PlaceholderRuntime:
            def __init__(self):
                self.fbx_available = False
                self.prompt_engineering_host = args.prompt_engineering_host

            def rewrite_text_and_infer_time(self, text: str):
                # For prompt rewriting, we don't need GPU
                from hymotion.prompt_engineering.prompt_rewrite import PromptRewriter
                rewriter = PromptRewriter(host=self.prompt_engineering_host)
                return rewriter.rewrite_prompt_and_infer_time(text)

        runtime = PlaceholderRuntime()
    else:
        # Local development: load model immediately
        print(">>> Local environment detected. Loading model at startup.")
        skip_model_loading = False
        if not os.path.exists(ckpt):
            # Missing checkpoint degrades gracefully: UI works, generation does not.
            print(f">>> [WARNING] Checkpoint file not found: {ckpt}")
            print(f">>> [WARNING] Model loading will be skipped. Motion generation will not be available.")
            skip_model_loading = True

        print(">>> Initializing T2MRuntime...")
        if "USE_HF_MODELS" not in os.environ:
            os.environ["USE_HF_MODELS"] = "1"

        skip_text = False
        runtime = T2MRuntime(
            config_path=cfg,
            ckpt_name=ckpt,
            skip_text=skip_text,
            device_ids=None,
            prompt_engineering_host=args.prompt_engineering_host,
            skip_model_loading=skip_model_loading,
        )
    _global_runtime = runtime  # Set global runtime for GPU function

    ui = T2MGradioUI(runtime=runtime, args=args)
    demo = ui.build_ui()
    return demo
887
+
888
+
889
# Create demo at module level for Hugging Face Spaces
# (Spaces imports this module and serves the module-level `demo` object.)
final_model_path = try_to_download_model()
demo = create_demo(final_model_path)

if __name__ == "__main__":
    # Local entry point: run the Gradio server directly.
    demo.launch()
hymotion/network/attention.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
try:
    import flash_attn
    from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func
except ImportError:
    # flash-attn is optional; without it only the "torch" and "vanilla"
    # attention modes are usable.
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None


# Per-mode (pre_attn, post_attn) tensor layout transforms.
# "flash": flatten batch into the token dim ([b, s, h, d] -> [b*s, h, d]) for the varlen kernel.
# "torch"/"vanilla": move heads before sequence ([b, s, h, d] -> [b, h, s, d]).
MEMORY_LAYOUT = {
    "flash": (lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), lambda x: x),
    "torch": (lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2)),
    "vanilla": (lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2)),
}


def attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    mode: str = "flash",
    drop_rate: float = 0.0,
    attn_mask: Optional[Tensor] = None,
    causal: bool = False,
    cu_seqlens_q: Optional[Tensor] = None,
    cu_seqlens_kv: Optional[Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_kv: Optional[int] = None,
    batch_size: int = 1,
    training: bool = True,
) -> Tensor:
    """Perform QKV (self or cross) attention.

    Args:
        q (Tensor): Query tensor with shape [b, s, h, d], where h is the number of heads.
        k (Tensor): Key tensor with shape [b, s1, h, d]
        v (Tensor): Value tensor with shape [b, s1, h, d]
        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, h, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        cu_seqlens_q (Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q. Required for 'flash' mode.
        cu_seqlens_kv (Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv. Required for 'flash' mode.
        max_seqlen_q (int): The maximum sequence length in the batch of q.
        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
        batch_size (int): Batch size used to un-flatten the 'flash' output. (default: 1)
        training (bool): Whether attention-map dropout is active ('vanilla' mode only). (default: True)

    Returns:
        Tensor: Output tensor after attention with shape [b, s, h*d]
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
    q = pre_attn_layout(q)
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        # Fused SDPA kernel. Boolean masks pass through; additive masks must match q's dtype.
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
    elif mode == "flash":
        assert flash_attn_varlen_func is not None, "flash_attn is not installed or not supported"
        x = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )
        # x with shape [(b*s), h, d]; restore the batch dimension.
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, h, d]
    elif mode == "vanilla":
        # Reference implementation: explicit softmax(QK^T / sqrt(d) + bias) @ V.
        scale_factor = 1.0 / math.sqrt(q.size(-1))
        b, a, s_q, _ = q.shape
        s_k = k.size(2)
        attn_bias = torch.zeros(b, a, s_q, s_k, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s_q, s_q, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(~temp_mask, float("-inf"))
            attn_bias = attn_bias.to(q.dtype)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(~attn_mask, float("-inf"))
            else:
                attn_bias = attn_bias + attn_mask

        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn = attn + attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=training)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    # Merge heads: [b, s, h, d] -> [b, s, h*d].
    out = x.reshape(x.shape[0], x.shape[1], -1)
    return out
hymotion/network/bricks.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+
7
+
8
def get_activation_layer(act_type: str) -> Callable[[], nn.Module]:
    """Return a zero-argument factory building the named activation module.

    Raises:
        ValueError: If `act_type` is not one of 'gelu', 'gelu_tanh', 'relu', 'silu'.
    """
    factories = {
        "gelu": lambda: nn.GELU(),
        "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
        "relu": nn.ReLU,
        "silu": nn.SiLU,
    }
    try:
        return factories[act_type]
    except KeyError:
        raise ValueError(f"Unknown activation type: {act_type}") from None
19
+
20
+
21
def get_norm_layer(norm_type: Optional[str]):
    """Map a normalization name to its module class.

    'layer' -> nn.LayerNorm, 'rms' -> RMSNorm, 'none'/None -> nn.Identity.

    Raises:
        ValueError: For any other value.
    """
    if norm_type == "layer":
        return nn.LayerNorm
    if norm_type == "rms":
        return RMSNorm
    if norm_type is None or norm_type == "none":
        return nn.Identity
    raise ValueError(f"Unknown norm type: {norm_type}")
30
+
31
+
32
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with an optional learned scale.

    Normalizes over the last dimension: x / sqrt(mean(x^2) + eps), computed
    in float32 for stability and cast back to the input dtype.
    """

    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            # Per-feature scale; the attribute is simply absent when affine is off.
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: Tensor) -> Tensor:
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        normed = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            normed = normed * self.weight
        return normed
hymotion/network/encoders.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+ from ..utils.misc import to_2tuple
9
+ from .bricks import get_activation_layer, get_norm_layer
10
+ from .modulate_layers import ModulateDiT, modulate
11
+
12
+
13
class MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> act -> dropout -> norm -> fc2 -> dropout.

    With `use_conv=True` the linear layers become 1x1 Conv2d. `bias` and `drop`
    may each be a scalar applied to both layers (expanded via to_2tuple).
    """

    def __init__(
        self,
        in_dim: int,
        feat_dim: int,
        out_dim: Optional[int] = None,
        act_type: str = "gelu",
        norm_type: Optional[str] = None,
        bias: bool = True,
        drop: float = 0.0,
        use_conv: bool = False,
    ) -> None:
        super().__init__()
        out_dim = out_dim or in_dim
        feat_dim = feat_dim or in_dim
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        bias1 = bias[0] if isinstance(bias, (list, tuple)) else bias
        bias2 = bias[1] if isinstance(bias, (list, tuple)) else bias
        drop1_p = drop_probs[0] if isinstance(drop_probs, (list, tuple)) else drop_probs

        self.fc1 = linear_layer(in_dim, feat_dim, bias=bias1)
        self.act = get_activation_layer(act_type)()
        self.drop1 = nn.Dropout(drop1_p)
        self.norm = get_norm_layer(norm_type)(feat_dim) if norm_type else nn.Identity()
        self.fc2 = linear_layer(feat_dim, out_dim, bias=bias2)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x: Tensor) -> Tensor:
        """Apply the full MLP stack to x."""
        for stage in (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2):
            x = stage(x)
        return x
47
+
48
+
49
class MLPEncoder(nn.Module):
    """Plain MLP encoder: Linear followed by (activation, Linear) repeated.

    With `num_layers` linear layers total, all hidden/output widths equal `feat_dim`.
    """

    def __init__(self, in_dim: int, feat_dim: int, num_layers: int, act_type: str = "silu") -> None:
        super(MLPEncoder, self).__init__()
        self.in_dim = in_dim
        self.feat_dim = feat_dim
        layers = [nn.Linear(in_features=in_dim, out_features=self.feat_dim)]
        for _ in range(num_layers - 1):
            layers.append(get_activation_layer(act_type)())
            layers.append(nn.Linear(self.feat_dim, self.feat_dim))
        self.linears = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Project x through the stacked linear/activation layers."""
        return self.linears(x)
63
+
64
+
65
class FinalLayer(nn.Module):
    """adaLN-modulated output head: LayerNorm -> shift/scale modulate -> Linear.

    The `adapter` conditioning vector predicts the shift/scale pair. With
    `zero_init=True` the final projection starts at zero (common DiT practice).
    """

    def __init__(self, feat_dim: int, out_dim: int, act_type: str = "gelu", zero_init=False, **kwargs):
        super().__init__()

        self.norm_final = nn.LayerNorm(feat_dim, elementwise_affine=False, eps=1e-6)
        self.adaLN_modulation = ModulateDiT(feat_dim, factor=2, act_type=act_type)
        self.linear = nn.Linear(feat_dim, out_dim, bias=True)
        if zero_init:
            # Zero-initialized projection so the head contributes nothing at step 0.
            nn.init.zeros_(self.linear.weight)
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: Tensor, adapter: Tensor) -> Tensor:
        """Normalize, condition on `adapter`, and project to `out_dim`."""
        shift, scale = self.adaLN_modulation(adapter).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift=shift, scale=scale)
        return self.linear(modulated)
81
+
82
+
83
class TimestepEmbeddingEncoder(nn.Module):
    """Encode scalar timesteps into features: sinusoidal embedding -> small MLP.

    forward() maps a 1-D batch of timesteps [B] to shape [B, 1, feat_dim].
    """

    def __init__(
        self,
        embedding_dim: int,
        feat_dim: int,
        act_type: str = "silu",
        time_factor: float = 1.0,
    ) -> None:
        super(TimestepEmbeddingEncoder, self).__init__()

        self.embedding_dim = embedding_dim
        self.feat_dim = feat_dim
        self.time_factor = time_factor
        self.blocks = nn.Sequential(
            nn.Linear(embedding_dim, self.feat_dim),
            get_activation_layer(act_type)(),
            nn.Linear(self.feat_dim, self.feat_dim),
        )

    def forward(self, t: Tensor) -> Tensor:
        """Embed timesteps sinusoidally, project, and add a singleton token dim."""
        emb = self.sinusodial_embedding(t, self.embedding_dim, time_factor=self.time_factor)
        return self.blocks(emb).unsqueeze(1)

    @staticmethod
    def sinusodial_embedding(
        timesteps: Tensor, embedding_dim: int, temperature: float = 10000.0, time_factor: float = 1.0
    ) -> Tensor:
        """Transformer-style sinusoidal embedding laid out as [cos | sin].

        Zero-pads the final column when `embedding_dim` is odd.
        """
        assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
        scaled = timesteps * time_factor
        half = embedding_dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float) / half
        freqs = torch.exp(-torch.log(torch.tensor(temperature)) * exponents).to(device=scaled.device)
        args = scaled[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if embedding_dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
hymotion/network/hymotion_mmdit.py ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+ from ..utils.loaders import load_object
12
+ from ..utils.type_converter import get_module_device
13
+ from .attention import attention
14
+ from .bricks import get_activation_layer, get_norm_layer
15
+ from .encoders import MLP, MLPEncoder, TimestepEmbeddingEncoder
16
+ from .modulate_layers import ModulateDiT, apply_gate, modulate
17
+ from .positional_encoding import RotaryEmbedding
18
+
19
+
20
class MMBaseBlock(nn.Module):
    """Shared base for the multi-modal DiT stream blocks.

    Stores head geometry, feed-forward width, and the rotary positional
    embedding used by the concrete stream blocks derived from it.
    """

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        mlp_ratio: float,
        dropout: float,
        positional_encoding_cfg: dict,
        apply_rope_to_single_branch: bool,
    ):
        super().__init__()
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        assert self.feat_dim % num_heads == 0, f"feat_dim {self.feat_dim} must be divisible by num_heads {num_heads}"
        # Per-head channel width.
        self.head_dim = self.feat_dim // num_heads

        # Hidden width of the derived blocks' feed-forward MLPs.
        self.mlp_hidden_dim = int(self.feat_dim * mlp_ratio)

        # Copied so later mutation of the caller's cfg dict cannot affect this block.
        self._positional_encoding_cfg = positional_encoding_cfg.copy()
        self.rotary_emb = RotaryEmbedding(num_feats=self.head_dim, **self._positional_encoding_cfg)
        # When True, derived double-stream blocks apply RoPE to the motion q/k only
        # (the text branch is left unrotated).
        self.apply_rope_to_single_branch = apply_rope_to_single_branch
45
+
46
class MMDoubleStreamBlock(MMBaseBlock):
    """Double-stream MM-DiT block.

    Motion and text tokens keep separate AdaLN modulation, QKV projections,
    output projections and MLPs, but attend jointly over the concatenated
    sequence (motion tokens first, then text tokens).
    """

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        mlp_ratio: float,
        dropout: float,
        mlp_act_type: str,
        qk_norm_type: Optional[str] = None,
        qkv_bias: bool = False,
        positional_encoding_cfg: dict = {
            "max_seq_len": 5000,
            "use_real": True,
        },
        apply_rope_to_single_branch: bool = True,
    ):
        # NOTE: the default-dict arguments are copied by MMBaseBlock.__init__
        # (self._positional_encoding_cfg = positional_encoding_cfg.copy()),
        # so the shared default is not mutated.
        super().__init__(feat_dim, num_heads, mlp_ratio, dropout, positional_encoding_cfg, apply_rope_to_single_branch)

        # --- motion branch ---
        # AdaLN head producing 6 chunks: shift/scale/gate for attention and MLP.
        self.motion_mod = ModulateDiT(
            self.feat_dim,
            factor=6,
            act_type="silu",
        )
        self.motion_norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)

        motion_qkv_out_dim = self.feat_dim * 3
        self.motion_qkv = nn.Linear(self.feat_dim, motion_qkv_out_dim, bias=qkv_bias)

        # per-head Q/K normalization (type chosen by qk_norm_type, e.g. RMS)
        self.motion_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.motion_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.motion_out_proj = nn.Linear(self.feat_dim, self.feat_dim, bias=qkv_bias)
        self.motion_norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
        self.motion_mlp = MLP(
            self.feat_dim,
            self.mlp_hidden_dim,
            act_type=mlp_act_type,
            bias=True,
        )

        # --- text branch (mirrors the motion branch) ---
        self.text_mod = ModulateDiT(
            self.feat_dim,
            factor=6,
            act_type="silu",
        )
        self.text_norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)

        text_qkv_out_dim = self.feat_dim * 3
        self.text_qkv = nn.Linear(self.feat_dim, text_qkv_out_dim, bias=qkv_bias)

        self.text_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.text_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.text_out_proj = nn.Linear(self.feat_dim, self.feat_dim, bias=qkv_bias)
        self.text_norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
        self.text_mlp = MLP(
            self.feat_dim,
            self.mlp_hidden_dim,
            act_type=mlp_act_type,
            bias=True,
        )

    def forward(
        self,
        motion_feat: Tensor,
        text_feat: Tensor,
        adapter: Tensor,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Run joint motion/text attention followed by per-branch gated MLPs.

        Args:
            motion_feat: motion tokens, (B, L_m, D).
            text_feat: text tokens, (B, L_t, D).
            adapter: conditioning vector driving both AdaLN heads.
            attn_mask: optional additive attention mask over the
                concatenated (L_m + L_t) sequence.

        Returns:
            Updated (motion_feat, text_feat) with residual connections applied.
        """
        # six-way AdaLN parameters per branch
        (
            motion_shift_msa,
            motion_scale_msa,
            motion_gate_msa,
            motion_shift_mlp,
            motion_scale_mlp,
            motion_gate_mlp,
        ) = self.motion_mod(adapter).chunk(6, dim=-1)
        (
            text_shift_msa,
            text_scale_msa,
            text_gate_msa,
            text_shift_mlp,
            text_scale_mlp,
            text_gate_mlp,
        ) = self.text_mod(adapter).chunk(6, dim=-1)

        # motion branch: norm -> modulate -> QKV
        motion_modulated = self.motion_norm1(motion_feat)
        motion_modulated = modulate(motion_modulated, shift=motion_shift_msa, scale=motion_scale_msa)
        motion_qkv = self.motion_qkv(motion_modulated)

        motion_q, motion_k, motion_v = rearrange(motion_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        motion_q = self.motion_q_norm(motion_q).to(motion_v)
        motion_k = self.motion_k_norm(motion_k).to(motion_v)

        if self.apply_rope_to_single_branch:
            # NOTE: we don't apply RoPE to text_branch_two here
            motion_q, motion_k = self.rotary_emb.apply_rotary_emb(motion_q, motion_k)

        # text branch: norm -> modulate -> QKV
        text_modulated = self.text_norm1(text_feat)
        text_modulated = modulate(text_modulated, shift=text_shift_msa, scale=text_scale_msa)
        text_qkv = self.text_qkv(text_modulated)

        text_q, text_k, text_v = rearrange(
            text_qkv,
            "B L (K H D) -> K B L H D",
            K=3,
            H=self.num_heads,
        )
        text_q = self.text_q_norm(text_q).to(text_v)
        text_k = self.text_k_norm(text_k).to(text_v)

        # joint attention over [motion ; text]
        q = torch.cat((motion_q, text_q), dim=1)
        k = torch.cat((motion_k, text_k), dim=1)
        v = torch.cat((motion_v, text_v), dim=1)

        if not self.apply_rope_to_single_branch:
            # alternative: rotate the whole concatenated sequence
            q, k = self.rotary_emb.apply_rotary_emb(q, k)

        bsz = q.shape[0]
        motion_len = motion_feat.shape[1]
        dropout_p = 0.0 if not self.training else self.dropout

        attn_output = attention(
            q,
            k,
            v,
            mode="torch",
            drop_rate=dropout_p,
            attn_mask=attn_mask,
            causal=False,
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
            max_seqlen_q=None,
            max_seqlen_kv=None,
            batch_size=bsz,
            training=self.training,
        )

        # split the joint attention output back into the two streams
        motion_attn_output, text_attn_output = (
            attn_output[:, :motion_len, ...],
            attn_output[:, motion_len:, ...],
        )

        # gated residual: attention then modulated MLP, per branch
        motion_feat = motion_feat + apply_gate(self.motion_out_proj(motion_attn_output), gate=motion_gate_msa)
        motion_feat = motion_feat + apply_gate(
            self.motion_mlp(
                modulate(
                    self.motion_norm2(motion_feat),
                    shift=motion_shift_mlp,
                    scale=motion_scale_mlp,
                )
            ),
            gate=motion_gate_mlp,
        )

        text_feat = text_feat + apply_gate(self.text_out_proj(text_attn_output), gate=text_gate_msa)
        text_feat = text_feat + apply_gate(
            self.text_mlp(
                modulate(
                    self.text_norm2(text_feat),
                    shift=text_shift_mlp,
                    scale=text_scale_mlp,
                )
            ),
            gate=text_gate_mlp,
        )

        return motion_feat, text_feat
215
+
216
+
217
class MMSingleStreamBlock(MMBaseBlock):
    """Single-stream MM-DiT block.

    Motion and text tokens share one fused linear (QKV + MLP-in) and one
    output linear (attn-proj + MLP-out); the first ``split_len`` tokens of
    the sequence are the motion tokens (used only for RoPE application).
    """

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        mlp_ratio: float,
        dropout: float,
        mlp_act_type: str,
        qk_norm_type: Optional[str] = None,
        qkv_bias: bool = False,
        positional_encoding_cfg: dict = {
            "max_seq_len": 5000,
            "use_real": True,
        },
        apply_rope_to_single_branch: bool = True,
    ):
        # default-dict argument is copied by MMBaseBlock.__init__, so the
        # shared default is not mutated.
        super().__init__(feat_dim, num_heads, mlp_ratio, dropout, positional_encoding_cfg, apply_rope_to_single_branch)

        # AdaLN head producing 3 chunks: shift/scale/gate for the fused layer
        self.modulation = ModulateDiT(self.feat_dim, factor=3, act_type="silu")
        self.norm = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)

        # qkv and mlp_in
        qkv_factor = 3
        self.linear1 = nn.Linear(self.feat_dim, self.feat_dim * qkv_factor + self.mlp_hidden_dim, bias=qkv_bias)
        # proj and mlp_out
        self.linear2 = nn.Linear(self.feat_dim + self.mlp_hidden_dim, self.feat_dim, bias=qkv_bias)

        self.q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)

        self.mlp_act = get_activation_layer(mlp_act_type)()

    def forward(
        self,
        x: Tensor,
        split_len: int,
        adapter: Tensor,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Process the concatenated [motion ; text] sequence.

        Args:
            x: concatenated tokens, (B, L_m + L_t, D).
            split_len: number of leading motion tokens (L_m).
            adapter: conditioning vector driving the AdaLN head.
            attn_mask: optional additive attention mask.

        Returns:
            Updated sequence with a single gated residual connection.
        """
        (
            shift_msa,
            scale_msa,
            gate_msa,
        ) = self.modulation(adapter).chunk(3, dim=-1)
        x_modulated = modulate(self.norm(x), shift_msa, scale_msa)

        # one fused projection feeds both attention (qkv) and the MLP branch
        qkv, mlp_hidden = torch.split(self.linear1(x_modulated), [3 * self.feat_dim, self.mlp_hidden_dim], dim=-1)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)

        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        q1, q2 = q[:, :split_len, ...], q[:, split_len:, ...]
        k1, k2 = k[:, :split_len, ...], k[:, split_len:, ...]
        # apply rotary position embedding (motion tokens only by default)
        if self.apply_rope_to_single_branch:
            q1, k1 = self.rotary_emb.apply_rotary_emb(q1, k1)
        q = torch.cat((q1, q2), dim=1)
        k = torch.cat((k1, k2), dim=1)
        if not self.apply_rope_to_single_branch:
            q, k = self.rotary_emb.apply_rotary_emb(q, k)

        bsz = x_modulated.shape[0]
        dropout_p = 0.0 if not self.training else self.dropout

        attn_output = attention(
            q,
            k,
            v,
            mode="torch",
            drop_rate=dropout_p,
            attn_mask=attn_mask,
            causal=False,
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
            max_seqlen_q=None,
            max_seqlen_kv=None,
            batch_size=bsz,
            training=self.training,
        )
        # fused output projection over [attention ; activated MLP hidden]
        output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp_hidden)), 2))

        return x + apply_gate(output, gate=gate_msa)
302
+
303
+
304
class HunyuanMotionMMDiT(nn.Module):
    """Text-to-motion MM-DiT: a stack of double-stream blocks followed by
    single-stream blocks over [motion ; text] tokens, with AdaLN conditioning
    from the diffusion timestep plus a pooled sentence embedding.

    Args:
        input_dim: per-frame motion feature dimension.
        feat_dim: transformer width.
        output_dim: output feature dimension; defaults to ``input_dim``.
        ctxt_input_dim: per-token LLM text feature dimension.
        vtxt_input_dim: pooled sentence-embedding dimension.
        text_refiner_module / text_refiner_cfg: optional token refiner applied
            to the encoded text tokens.
        num_layers: total block count; the first third are double-stream
            blocks, the rest single-stream.
        mask_mode: None, "causal", or "narrowband" self-attention masking on
            the motion tokens.
        insert_start_token: prepend a learned token to the motion sequence.
        with_long_skip_connection: add a modulated skip from the encoded
            input to the last hidden state.
        narrowband_length: band half-width; multiplied by 30.0 internally
            (presumably seconds -> frames at 30 fps -- TODO confirm).
    """

    def __init__(
        self,
        input_dim: int,
        feat_dim: int,
        output_dim: Optional[int] = None,
        ctxt_input_dim: int = 4096,
        vtxt_input_dim: int = 256,
        text_refiner_module: str = "hymotion/network/token_refiner.SingleTokenRefiner",
        text_refiner_cfg: dict = {
            "num_layers": 2,
        },
        num_layers: int = 12,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        norm_type: str = "layer",
        qk_norm_type: str = "rms",
        qkv_bias: bool = True,
        dropout: float = 0.0,
        final_layer_module: str = "hymotion/network/encoders.FinalLayer",
        final_layer_cfg: dict = {
            "act_type": "silu",
        },
        mask_mode: Optional[str] = None,
        apply_rope_to_single_branch: bool = True,
        insert_start_token: bool = False,
        with_long_skip_connection: bool = False,
        time_factor: float = 1.0,
        narrowband_length: float = 2.0,
        **kwargs,
    ):
        super().__init__()
        self.motion_input_dim = input_dim
        self.ctxt_input_dim = ctxt_input_dim
        self.vtxt_input_dim = vtxt_input_dim
        self.feat_dim = feat_dim
        self.output_dim = output_dim or input_dim
        self.mask_mode = mask_mode
        self.insert_start_token = insert_start_token
        self.time_factor = time_factor
        # convert band width to frames (x30 -- presumably 30 fps; see docstring)
        self.narrowband_length = narrowband_length * 30.0
        if self.insert_start_token:
            self.start_token = nn.Parameter(torch.randn(1, feat_dim))
        self.with_long_skip_connection = with_long_skip_connection
        if self.with_long_skip_connection:
            from .encoders import FinalLayer

            self.long_skip_net = FinalLayer(feat_dim=feat_dim, out_dim=feat_dim, act_type="silu")

        self.input_encoder = nn.Linear(in_features=input_dim, out_features=feat_dim)
        self.ctxt_encoder = nn.Linear(in_features=ctxt_input_dim, out_features=feat_dim)
        self.vtxt_encoder = MLPEncoder(in_dim=vtxt_input_dim, feat_dim=feat_dim, num_layers=2, act_type="silu")
        self.timestep_encoder = TimestepEmbeddingEncoder(
            embedding_dim=feat_dim,
            feat_dim=feat_dim,
            time_factor=time_factor,
        )

        if text_refiner_module != "" and text_refiner_module is not None:
            # Copy before updating: mutating the default-argument dict in
            # place would leak these keys into later instantiations.
            text_refiner_cfg = dict(text_refiner_cfg)
            text_refiner_cfg.update(input_dim=feat_dim, feat_dim=feat_dim, num_heads=num_heads)
            self._text_refiner_cfg = text_refiner_cfg.copy()
            self.text_refiner = load_object(text_refiner_module, text_refiner_cfg)

        self.num_layers = num_layers
        assert num_layers % 3 == 0, f"num_layers must be divisible by 3, but got {num_layers}"
        self.mm_double_blocks_layers = int(num_layers // 3)
        self.mm_single_blocks_layers = int(num_layers - num_layers // 3)

        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    feat_dim=feat_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    mlp_act_type=mlp_act_type,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    apply_rope_to_single_branch=apply_rope_to_single_branch,
                )
                for _ in range(self.mm_double_blocks_layers)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                MMSingleStreamBlock(
                    feat_dim=feat_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    mlp_act_type=mlp_act_type,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    apply_rope_to_single_branch=apply_rope_to_single_branch,
                )
                for _ in range(self.mm_single_blocks_layers)
            ]
        )

        # Same default-dict pitfall as above: copy before mutating.
        final_layer_cfg = dict(final_layer_cfg)
        final_layer_cfg.update(feat_dim=feat_dim, out_dim=self.output_dim)
        self._final_layer_cfg = final_layer_cfg.copy()
        self.final_layer = load_object(final_layer_module, final_layer_cfg)

    def forward(
        self,
        x: Tensor,
        ctxt_input: Tensor,
        vtxt_input: Tensor,
        timesteps: Tensor,
        x_mask_temporal: Tensor,
        ctxt_mask_temporal: Tensor,
        **kwargs,
    ) -> Tensor:
        """Denoise one step.

        Args:
            x: noisy motion features, (B, L_m, input_dim).
            ctxt_input: per-token LLM text features, (B, L_t, ctxt_input_dim).
            vtxt_input: pooled sentence embedding, (B, 1, vtxt_input_dim).
            timesteps: diffusion timesteps, (B,).
            x_mask_temporal: bool validity mask over motion frames, (B, L_m).
            ctxt_mask_temporal: bool validity mask over text tokens, (B, L_t).

        Returns:
            Predicted residual/output, (B, L_m, output_dim).
        """
        device = get_module_device(self)

        motion_feat = self.input_encoder(x)
        if self.with_long_skip_connection:
            origin_feat = motion_feat
        if self.insert_start_token:
            # (B, 1, D) + (B, L, D) -> (B, L+1, D)
            start_token = self.start_token[None].repeat(motion_feat.shape[0], 1, 1)
            motion_feat = torch.cat((start_token, motion_feat), dim=1)
            # the start token is always valid -> extend the mask accordingly
            x_mask_temporal = torch.cat(
                [
                    torch.ones_like(x_mask_temporal[:, :1], dtype=torch.bool),
                    x_mask_temporal,
                ],
                dim=1,
            )

        # AdaLN conditioning = timestep embedding + pooled text embedding
        timestep_feat = self.timestep_encoder(timesteps)
        vtxt_feat = self.vtxt_encoder(vtxt_input.float())
        adapter = timestep_feat + vtxt_feat

        # key-padding masks: 0 for valid positions, -inf for padding
        motion_key_padding_mask = self._canonical_mask(x_mask_temporal).to(device)
        ctxt_key_padding_mask = self._canonical_mask(ctxt_mask_temporal).to(device)
        seq_key_padding_mask = torch.cat((motion_key_padding_mask, ctxt_key_padding_mask), dim=1)
        if self.mask_mode is None:
            seq_mask = None
        elif self.mask_mode == "causal":
            # motion frames may only attend to past frames
            motion_len = motion_feat.shape[1]
            seq_mask = torch.triu(
                torch.full((motion_len, motion_len), float("-inf"), device=device),
                diagonal=1,
            )
        elif self.mask_mode == "narrowband":
            # motion frames attend only within a +-window band around themselves
            window = int(round(self.narrowband_length))
            motion_len = motion_feat.shape[1]
            idx = torch.arange(motion_len, device=device)
            dist = (idx[None, :] - idx[:, None]).abs()
            band = dist <= window
            seq_mask = torch.full((motion_len, motion_len), float("-inf"), device=device)
            seq_mask = seq_mask.masked_fill(band, 0.0)
        else:
            raise ValueError(f"Unsupported mask mode: {self.mask_mode}")

        ctxt_feat = self.ctxt_encoder(ctxt_input.float())
        if hasattr(self, "text_refiner"):
            ctxt_feat = self.text_refiner(x=ctxt_feat, t=timesteps, mask=(ctxt_key_padding_mask == 0).to(device))

        # precompute shared attention masks (broadcastable over heads)
        bsz = x.shape[0]
        motion_len = motion_feat.shape[1]
        text_len = ctxt_feat.shape[1]
        total_len = motion_len + text_len
        mask_dtype = motion_feat.dtype
        attn_mask_double = self._build_dmm_attn_mask_shared(
            bsz=bsz,
            motion_len=motion_len,
            text_len=text_len,
            dtype=mask_dtype,
            key_padding_mask=seq_key_padding_mask,
            attn_mask=seq_mask,
            device=device,
        )
        for mod in self.double_blocks:
            motion_feat, ctxt_feat = mod(
                motion_feat=motion_feat,
                text_feat=ctxt_feat,
                adapter=adapter,
                attn_mask=attn_mask_double,
            )

        # precompute shared attention masks for single stream blocks too
        split_len = motion_feat.shape[1]
        x = torch.cat((motion_feat, ctxt_feat), 1)
        attn_mask_single = self._build_smm_attn_mask_shared(
            bsz=bsz,
            split_len=split_len,
            total_len=total_len,
            dtype=mask_dtype,
            key_padding_mask=seq_key_padding_mask,
            attn_mask=seq_mask,
            device=device,
        )
        for mod in self.single_blocks:
            x = mod(
                x=x,
                split_len=split_len,
                adapter=adapter,
                attn_mask=attn_mask_single,
            )

        # keep only the motion tokens (and drop the start token if inserted)
        x = x[:, :split_len, ...]
        if self.insert_start_token:
            x = x[:, 1:, ...]

        if self.with_long_skip_connection:
            # long skip only consider timestep_feat
            x = self.long_skip_net(origin_feat, timestep_feat) + x

        predicted_res = self.final_layer(x, adapter)
        return predicted_res

    @staticmethod
    def _canonical_mask(input_mask: Tensor) -> Tensor:
        """Convert a bool validity mask to additive form: True -> 0, False -> -inf."""
        if input_mask.ndim == 1:
            input_mask = input_mask.unsqueeze(1)
        key_padding_mask = torch.where(
            input_mask,
            torch.zeros_like(input_mask, dtype=torch.float),
            torch.full_like(input_mask, float("-inf"), dtype=torch.float),
        )
        return key_padding_mask

    def _build_dmm_attn_mask_shared(
        self,
        bsz: int,
        motion_len: int,
        text_len: int,
        dtype: torch.dtype,
        key_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor],
        device: torch.device,
    ) -> Tensor:
        """Build the additive (B, 1, total, total) mask for double-stream blocks.

        NOTE:
                     motion_k  text_k
            motion_q  [M→M]     [M→T]
            text_q    [T→M]     [T→T]
            only [M→M] contains given mask
        """
        total_len = motion_len + text_len
        base = torch.zeros((bsz, 1, total_len, total_len), dtype=dtype, device=device)
        if attn_mask is not None:
            if attn_mask.dim() != 2 or attn_mask.shape != (motion_len, motion_len):
                raise RuntimeError(
                    f"attn_mask should be 2D with shape {(motion_len, motion_len)}, got {attn_mask.shape}"
                )
            base[:, :, :motion_len, :motion_len] += attn_mask.view(1, 1, motion_len, motion_len)
        if key_padding_mask is not None:
            mask_total_len = key_padding_mask.shape[1]
            if mask_total_len == motion_len:
                # mask covers only motion tokens: treat all text keys as valid
                pad = torch.zeros((bsz, text_len), dtype=key_padding_mask.dtype, device=device)
                key_padding_mask = torch.cat((key_padding_mask, pad), dim=-1)
            base = base + key_padding_mask.view(bsz, 1, 1, total_len)
        # disable T→M
        base[:, :, motion_len:, :motion_len] = float("-inf")
        return base

    def _build_smm_attn_mask_shared(
        self,
        bsz: int,
        split_len: int,
        total_len: int,
        dtype: torch.dtype,
        key_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor],
        device: torch.device,
    ) -> Tensor:
        """Build the additive (B, 1, total, total) mask for single-stream blocks.

        NOTE:
                     motion_k  text_k
            motion_q  [M→M]     [M→T]
            text_q    [T→M]     [T→T]
            only [M→M] contains given mask
        """
        base = torch.zeros((bsz, 1, total_len, total_len), dtype=dtype, device=device)
        if attn_mask is not None:
            if attn_mask.dim() != 2 or attn_mask.shape != (split_len, split_len):
                raise RuntimeError(f"attn_mask should be 2D with shape {(split_len, split_len)}, got {attn_mask.shape}")
            base[:, :, :split_len, :split_len] += attn_mask.view(1, 1, split_len, split_len)
        if key_padding_mask is not None:
            mask_total_len = key_padding_mask.shape[1]
            if mask_total_len == split_len:
                # mask covers only motion tokens: treat all text keys as valid
                pad = torch.zeros(
                    (bsz, total_len - split_len),
                    dtype=key_padding_mask.dtype,
                    device=device,
                )
                key_padding_mask = torch.cat((key_padding_mask, pad), dim=-1)
            base = base + key_padding_mask.view(bsz, 1, 1, total_len)
        # disable T→M
        base[:, :, split_len:, :split_len] = float("-inf")
        return base
601
+
602
+
603
if __name__ == "__main__":
    # Smoke test; run with: python -m hymotion.network.hymotion_mmdit

    from configs._base_.model_network_base import MOTION_MODEL_CONFIG  # pyright: ignore

    network_module_cfg = MOTION_MODEL_CONFIG["1.04B_narrowband"]["network_module_args"]
    network_module_cfg = dict(network_module_cfg)  # convert to normal dict

    bsz, seq_len, text_seq_len, input_dim = 1, 360, 128, 201
    network_module_cfg["input_dim"] = input_dim
    MMDiT = HunyuanMotionMMDiT(**network_module_cfg)

    # random inputs: noisy motion, per-token LLM features, pooled sentence embedding
    x = torch.randn(bsz, seq_len, input_dim)
    ctxt_condition = torch.randn(bsz, text_seq_len, 4096)
    vtxt_condition = torch.randn(bsz, 1, 768)
    timesteps = torch.randint(0, 1000, (bsz,))
    # validity masks: only the first 100 motion frames / 50 text tokens are valid
    length = torch.arange(seq_len).unsqueeze(0).repeat(bsz, 1)
    ctxt_length = torch.arange(text_seq_len).unsqueeze(0).repeat(bsz, 1)
    x_mask_temporal = length < 100
    ctxt_mask_temporal = ctxt_length < 50
    x = MMDiT(
        x=x,
        ctxt_input=ctxt_condition,
        vtxt_input=vtxt_condition,
        timesteps=timesteps,
        x_mask_temporal=x_mask_temporal,
        ctxt_mask_temporal=ctxt_mask_temporal,
    )
    # the model must preserve the (batch, time, channel) layout of the input
    assert x.shape == (
        bsz,
        seq_len,
        input_dim,
    ), f"unexpected output shape: {x.shape}, which should be ({bsz}, {seq_len}, {input_dim})"
    print(x.shape)
hymotion/network/modulate_layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+
6
+ from .bricks import get_activation_layer
7
+
8
+
9
class ModulateDiT(nn.Module):
    """AdaLN modulation head: nonlinearity followed by a linear projection
    to ``factor * feat_dim`` features (later chunked into shift/scale/gate).

    The projection's weight and bias are zero-initialized, so the head
    outputs all-zeros at the start of training.
    """

    def __init__(self, feat_dim: int, factor: int, act_type: str = "silu"):
        super().__init__()
        self.act = get_activation_layer(act_type)()
        self.linear = nn.Linear(feat_dim, factor * feat_dim, bias=True)
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)

    def forward(self, x: Tensor) -> Tensor:
        """Project the activated conditioning vector to modulation parameters."""
        activated = self.act(x)
        return self.linear(activated)
19
+
20
+
21
def modulate(x: Tensor, shift: Optional[Tensor] = None, scale: Optional[Tensor] = None) -> Tensor:
    """Apply AdaLN-style affine modulation: ``x * (1 + scale) + shift``.

    Either term may be None, in which case only the other is applied;
    with both None, ``x`` is returned unchanged.

    Args:
        x: input tensor.
        shift: additive term, same rank as ``x`` (broadcastable).
        scale: multiplicative term (applied as ``1 + scale``), same rank as ``x``.

    Returns:
        The modulated tensor.
    """
    # NOTE: the original evaluated the both-present condition twice; the
    # branch structure below checks each case exactly once.
    if shift is not None and scale is not None:
        assert len(x.shape) == len(shift.shape) == len(scale.shape), (
            "x, shift, scale must have the same number of dimensions, "
            f"but got x.shape: {x.shape}, "
            f"shift.shape: {shift.shape} "
            f"and scale.shape: {scale.shape}"
        )
        return x * (1 + scale) + shift
    if shift is not None:
        return x + shift
    if scale is not None:
        return x * (1 + scale)
    return x
37
+
38
+
39
def apply_gate(x: Tensor, gate: Optional[Tensor] = None, tanh: bool = False) -> Tensor:
    """Multiply ``x`` by an optional gating tensor.

    Args:
        x: input tensor.
        gate: gating tensor with the same rank as ``x``; None is a no-op.
        tanh: if True, squash the gate through tanh before multiplying.

    Returns:
        ``x`` unchanged when ``gate`` is None, otherwise the gated product.
    """
    if gate is None:
        return x
    assert len(x.shape) == len(
        gate.shape
    ), f"x, gate must have the same number of dimensions, but got {x.shape} and {gate.shape}"
    factor = gate.tanh() if tanh else gate
    return x * factor
hymotion/network/positional_encoding.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+
9
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) over the time axis of (B, S, H, D) tensors.

    Frequencies are precomputed for ``max_seq_len`` positions at construction
    time, either as (cos, sin) real tables (``use_real=True``) or as a complex
    ``freqs_cis`` tensor.  The two representations use different channel
    layouts (repeat-interleaved vs. complex-packed pairs) -- do not mix them.
    """

    def __init__(
        self,
        num_feats: int,
        max_seq_len: Union[Tensor, int],
        temperature: int = 10000,
        use_real: bool = False,
        theta_rescale_factor: float = 1.0,
        interpolation_factor: float = 1.0,
    ) -> None:
        super(RotaryEmbedding, self).__init__()
        assert num_feats % 2 == 0, "num_feats (head_dim) must be even for RoPE."
        self.num_feats = num_feats
        self.max_seq_len = max_seq_len
        self.temperature = temperature
        self.use_real = use_real
        self.theta_rescale_factor = theta_rescale_factor
        self.interpolation_factor = interpolation_factor

        # an int max_seq_len means "positions 0..max_seq_len-1"; a tensor
        # argument supplies the positions directly
        if isinstance(max_seq_len, int):
            max_seq_len = torch.arange(max_seq_len).float()

        # NTK-style base rescaling for context extension
        if theta_rescale_factor != 1.0:
            temperature *= theta_rescale_factor ** (self.num_feats / (self.num_feats - 2))
        dim_t = torch.arange(0, self.num_feats, 2, dtype=torch.float32)
        freqs = 1.0 / (temperature ** (2 * torch.div(dim_t, 2, rounding_mode="trunc") / self.num_feats))  # [D/2]
        # interpolation_factor < 1 compresses positions (position interpolation)
        freqs = torch.outer(max_seq_len.float() * interpolation_factor, freqs)  # [S, D/2]
        if use_real:
            # real-space tables; repeat_interleave pairs channels (2i, 2i+1)
            freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
            freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
            self.freqs_cis = (freqs_cos, freqs_sin)
        else:
            # complex-space table: e^{i*freq} per (position, channel-pair)
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # [S, D/2]
            self.freqs_cis = freqs_cis

    def reshape_for_broadcast(
        self, freqs_cis: Union[Tensor, Tuple[Tensor, Tensor]], x: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """Reshape the frequency table(s) to broadcast against ``x``.

        Dim 1 of ``x`` is treated as the sequence axis and the last dim as the
        head/channel axis; all other dims become singleton.  The table is
        truncated to the first ``x.shape[1]`` positions.
        """
        ndim = x.ndim
        assert 0 <= 1 < ndim

        if isinstance(freqs_cis, tuple):
            # freqs_cis: (cos, sin) in real space
            assert (
                freqs_cis[0].shape[-1] == x.shape[-1]
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape} on the head_dim dimension"
            assert freqs_cis[0].shape[0] >= x.shape[1], (
                f"freqs_cis shape {freqs_cis[0].shape} should be larger than or equal to "
                f"x shape {x.shape} on the time dimension"
            )
            shape = []
            for i, d in enumerate(x.shape):
                if i == 1:
                    shape.append(-1)
                elif i == ndim - 1:
                    shape.append(d)
                else:
                    shape.append(1)
            return (
                freqs_cis[0].view(*shape)[:, : x.shape[1], ...],
                freqs_cis[1].view(*shape)[:, : x.shape[1], ...],
            )
        else:
            # freqs_cis: values in complex space
            assert (
                freqs_cis.shape[-1] == x.shape[-1]
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape} on the head_dim dimension"
            assert freqs_cis.shape[0] >= x.shape[1], (
                f"freqs_cis shape {freqs_cis.shape} should be larger than or equal to "
                f"x shape {x.shape} on the time dimension"
            )
            shape = []
            for i, d in enumerate(x.shape):
                if i == 1:
                    shape.append(-1)
                elif i == ndim - 1:
                    shape.append(d)
                else:
                    shape.append(1)
            return freqs_cis.view(*shape)[:, : x.shape[1], ...]

    def rotate_half(self, x: Tensor) -> Tensor:
        """Map interleaved channel pairs (a, b) -> (-b, a), i.e. multiply by i."""
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

    def apply_rotary_emb(self, xq: Tensor, xk: Tensor) -> Tuple[Tensor, Tensor]:
        """Rotate query and key tensors of shape (B, S, H, D) by position.

        Returns the rotated (xq, xk) pair, cast back to the input dtypes.
        """
        xk_out = None
        if isinstance(self.freqs_cis, tuple):
            cos, sin = self.reshape_for_broadcast(self.freqs_cis, xq)  # [B, L, H, D]
            cos, sin = cos.to(xq.device), sin.to(xq.device)
            # real * cos - imag * sin
            # imag * cos + real * sin
            xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
            xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
        else:
            # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
            xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
            freqs_cis = self.reshape_for_broadcast(self.freqs_cis, xq_)
            # Handle device transfer based on return type
            # NOTE(review): self.freqs_cis is a plain Tensor on this branch, so
            # the tuple case below appears unreachable -- kept for safety.
            if isinstance(freqs_cis, tuple):
                freqs_cis = (freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device))
            else:
                freqs_cis = freqs_cis.to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
            # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
            # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
            xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

        return xq_out, xk_out

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f"(num_feats={self.num_feats}, "
        repr_str += f"max_seq_len={self.max_seq_len}, "
        repr_str += f"temperature={self.temperature}, "
        repr_str += f"use_real={self.use_real}, "
        repr_str += f"theta_rescale_factor={self.theta_rescale_factor}, "
        repr_str += f"interpolation_factor={self.interpolation_factor})"
        return repr_str
+ return repr_str
129
+
130
+
131
class PositionalEncoding(nn.Module):
    """Classic sinusoidal positional encoding with dropout.

    A (1, max_len, num_feats) table is precomputed at construction and
    registered as a (non-trainable) buffer so it follows the module across
    devices; even channels hold sines, odd channels cosines.
    """

    def __init__(self, num_feats: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [L, 1]
        inv_freq = torch.exp(torch.arange(0, num_feats, 2).float() * (-np.log(10000.0) / num_feats))
        table = torch.zeros(max_len, num_feats)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        self.register_buffer("pe", table.unsqueeze(0))  # shape of [1, L, D]

    def forward(self, x: Tensor) -> Tensor:
        """Add encodings for the first ``x.shape[1]`` positions, then apply dropout."""
        return self.dropout(x + self.pe[:, : x.shape[1], :])  # shape of [B, L, D]
+ return self.dropout(x)
147
+
148
+
149
if __name__ == "__main__":
    # Visual sanity check; run with: python -m hymotion.network.positional_encoding
    # Applies RoPE to a concatenated motion(360) + text(256) sequence of ones
    # and saves the raw q.k score map to attn.png.
    num_feats = 32
    rope = RotaryEmbedding(num_feats=num_feats, max_seq_len=5000, use_real=True)
    x = torch.ones(1, 360, 1, num_feats)
    text = torch.ones(1, 256, 1, num_feats)
    q1, k1 = x.clone(), x.clone()
    q2, k2 = text.clone(), text.clone()
    print(x.shape)
    # alternative experiment: rotate each branch independently
    # q1, k1 = rope.apply_rotary_emb(q1, k1)
    # q2, k2 = rope.apply_rotary_emb(q2, k2)
    q = torch.cat([q1, q2], dim=1)
    k = torch.cat([k1, k2], dim=1)
    q, k = rope.apply_rotary_emb(q, k)
    q, k = q[0, :, 0, :], k[0, :, 0, :]
    attn = (q[:, None] * k[None, :]).sum(dim=-1)
    # softmax (left raw so the relative-position structure stays visible)
    # attn = torch.softmax(attn, dim=-1)
    attn = attn.cpu().numpy()

    import matplotlib.pyplot as plt

    plt.imshow(attn, cmap="hot")
    plt.colorbar()
    plt.savefig("attn.png")
    # NOTE: removed a stray breakpoint() that halted the script after saving.
+ breakpoint()
hymotion/network/text_encoders/model_constants.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __all__ = [
2
+ "PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION",
3
+ ]
4
+
5
+
6
+ PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION = """
7
+ Summarize human motion only from the user text for representation: action categories, key body-part movements, order/transitions, trajectory/direction, posture; include style/emotion/speed only if present. Explicitly capture laterality (left/right) when mentioned; do not guess. If multiple actions are described, indicate the count of distinct actions (e.g., actions=3) and their order. Do not invent missing info. Keep one concise paragraph.
8
+ """
hymotion/network/text_encoders/text_encoder.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+ from transformers import (
8
+ AutoModelForCausalLM,
9
+ AutoTokenizer,
10
+ CLIPTextModel,
11
+ CLIPTokenizer,
12
+ )
13
+
14
+ from ...utils.type_converter import get_module_device
15
+ from .model_constants import PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION
16
+
17
# When USE_HF_MODELS=1, encoder weights are resolved from the HuggingFace hub;
# otherwise they are loaded from the local ``ckpts/`` directory.
USE_HF_MODELS = os.environ.get("USE_HF_MODELS", "0") == "1"

if USE_HF_MODELS:
    QWEN_PATH = "Qwen/Qwen3-8B"
    CLIP_PATH = "openai/clip-vit-large-patch14"
else:
    QWEN_PATH = "ckpts/Qwen3-8B"
    CLIP_PATH = "ckpts/clip-vit-large-patch14"

# Per-LLM layout: model path, chat template (system prompt + "{}" user slot),
# number of leading tokens to crop from the encoded sequence, and the
# tokenizer / encoder classes used to load it.
LLM_ENCODER_LAYOUT = {
    "qwen3": {
        "module_path": QWEN_PATH,
        "template": [
            {"role": "system", "content": f"{PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION}"},
            {"role": "user", "content": "{}"},
        ],
        "crop_start": 0,
        "tokenizer_class": AutoTokenizer,
        "text_encoder_class": AutoModelForCausalLM,
    },
}

# Per-sentence-embedder layout: models producing a single pooled embedding
# per prompt ("pooling_mode" selects which encoder output is used).
SENTENCE_EMB_LAYOUT = {
    "clipl": {
        "module_path": CLIP_PATH,
        "tokenizer_class": CLIPTokenizer,
        "text_encoder_class": CLIPTextModel,
        "pooling_mode": "pooler_output",
        "max_length": 77,
    },
}
48
+
49
+
50
class HYTextModel(nn.Module):
    """Dual text encoder: a causal LLM for per-token features plus a
    sentence-embedding model for one pooled vector per prompt.

    ``encode`` returns ``(vtxt_raw, ctxt_raw, ctxt_length)`` where
    ``vtxt_raw`` is [B, 1, vtxt_dim] (pooled sentence embedding),
    ``ctxt_raw`` is [B, max_length_llm, ctxt_dim] (LLM hidden states with the
    chat-template prefix cropped off), and ``ctxt_length`` is [B] valid-token
    counts. Either sub-encoder can be disabled by passing ``None`` as its type.
    Both sub-encoders are frozen (eval mode, ``requires_grad_(False)``).
    """

    def __init__(
        self,
        llm_type: Optional[str] = "qwen3",
        max_length_llm: int = 512,
        sentence_emb_type: Optional[str] = "clipl",
        max_length_sentence_emb: int = 77,
        enable_llm_padding: bool = True,
    ) -> None:
        """Build the (optional) LLM and sentence-embedding sub-encoders.

        Args:
            llm_type: key into ``LLM_ENCODER_LAYOUT`` or ``None`` to disable.
            max_length_llm: number of usable prompt tokens AFTER the template
                prefix is cropped (the tokenizer budget is extended by
                ``crop_start`` internally).
            sentence_emb_type: key into ``SENTENCE_EMB_LAYOUT`` or ``None``.
            max_length_sentence_emb: tokenizer max length for the sentence
                encoder; falls back to the layout's value when falsy.
            enable_llm_padding: pad LLM batches to ``max_length`` when True,
                otherwise pad only to the longest sample.
        """
        super().__init__()
        self.text_encoder_type = "hy_text_model"

        # --- Sentence-embedding branch (pooled vector per prompt) ---
        self.sentence_emb_type = sentence_emb_type
        self.sentence_emb_text_encoder = None
        self.sentence_emb_tokenizer = None
        self.vtxt_dim = 0  # hidden size of the sentence encoder (0 = disabled)
        if sentence_emb_type is not None:
            assert sentence_emb_type in SENTENCE_EMB_LAYOUT, f"Unsupported sentence embedding type: {sentence_emb_type}"
            # Falsy caller value (e.g. 0/None) falls back to the layout default.
            self.max_length_sentence_emb = max_length_sentence_emb or SENTENCE_EMB_LAYOUT[sentence_emb_type].get(
                "max_length", 77
            )
            self._sentence_emb_pooling_mode = SENTENCE_EMB_LAYOUT[sentence_emb_type].get(
                "pooling_mode", "pooler_output"
            )
            tokenizer_kwargs = SENTENCE_EMB_LAYOUT[sentence_emb_type].get("tokenizer_kwargs", {})

            self.sentence_emb_tokenizer = SENTENCE_EMB_LAYOUT[sentence_emb_type]["tokenizer_class"].from_pretrained(
                SENTENCE_EMB_LAYOUT[sentence_emb_type]["module_path"],
                max_length=self.max_length_sentence_emb,
                **tokenizer_kwargs,
            )
            self.sentence_emb_text_encoder = SENTENCE_EMB_LAYOUT[sentence_emb_type][
                "text_encoder_class"
            ].from_pretrained(SENTENCE_EMB_LAYOUT[sentence_emb_type]["module_path"])
            # Frozen feature extractor.
            self.sentence_emb_text_encoder = self.sentence_emb_text_encoder.eval().requires_grad_(False)
            self.vtxt_dim = self.sentence_emb_text_encoder.config.hidden_size

        # --- LLM branch (per-token features) ---
        self.llm_type = llm_type
        self.llm_text_encoder = None
        self.llm_tokenizer = None
        self.ctxt_dim = 0  # hidden size of the LLM (0 = disabled)
        self.crop_start = 0  # number of template-prefix tokens to drop
        self.max_length_llm = max_length_llm
        if llm_type is not None:
            assert llm_type in LLM_ENCODER_LAYOUT, f"Unsupported LLM type: {llm_type}"
            self._orig_max_length_llm = max_length_llm
            self.enable_llm_padding = enable_llm_padding
            self.llm_tokenizer = LLM_ENCODER_LAYOUT[llm_type]["tokenizer_class"].from_pretrained(
                LLM_ENCODER_LAYOUT[llm_type]["module_path"],
                padding_side="right",
            )
            self.llm_text_encoder = LLM_ENCODER_LAYOUT[llm_type]["text_encoder_class"].from_pretrained(
                LLM_ENCODER_LAYOUT[llm_type]["module_path"], low_cpu_mem_usage=True
            )
            self.llm_text_encoder = self.llm_text_encoder.eval().requires_grad_(False)
            self.ctxt_dim = self.llm_text_encoder.config.hidden_size

            # Extend the tokenizer budget so that after cropping the template
            # prefix we still keep `_orig_max_length_llm` prompt tokens.
            self.crop_start = self._compute_crop_start()
            self.max_length_llm = self._orig_max_length_llm + self.crop_start

    @torch.no_grad()
    def encode_llm(self, text: List[str]) -> Tuple[Tensor, Tensor]:
        """Encode prompts with the LLM.

        Returns:
            ctxt_raw: [B, _orig_max_length_llm, ctxt_dim] hidden states with
                the chat-template prefix cropped off.
            ctxt_length: [B] valid token count per sample after cropping,
                clamped to [0, _orig_max_length_llm].

        Raises:
            ValueError: if the LLM branch was disabled at construction.
        """
        if self.llm_type is None or self.llm_text_encoder is None or self.llm_tokenizer is None:
            raise ValueError("LLM model not initialized")

        device = get_module_device(self)
        # qwen3 prompts go through the chat template (thinking disabled);
        # other LLM types use a plain format-string template.
        llm_text = [
            (
                self.llm_tokenizer.apply_chat_template(
                    self.apply_text_to_template(one_text, LLM_ENCODER_LAYOUT[self.llm_type]["template"]),
                    tokenize=False,
                    add_generation_prompt=False,
                    enable_thinking=False,
                )
                if self.llm_type == "qwen3"
                else self.apply_text_to_template(one_text, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
            )
            for one_text in text
        ]
        padding_mode = "max_length" if self.enable_llm_padding else False
        llm_batch_encoding = self.llm_tokenizer(
            llm_text,
            return_length=False,
            return_overflowing_tokens=False,
            truncation=True,
            return_attention_mask=True,
            max_length=self.max_length_llm,  # = crop_start + _orig_max_length_llm
            padding=padding_mode,
            return_tensors="pt",
        )
        # qwen3 needs output_hidden_states to expose the last hidden layer.
        llm_outputs = (
            self.llm_text_encoder(
                input_ids=llm_batch_encoding["input_ids"].to(device),
                attention_mask=llm_batch_encoding["attention_mask"].to(device),
                output_hidden_states=True,
            )
            if self.llm_type == "qwen3"
            else self.llm_text_encoder(
                input_ids=llm_batch_encoding["input_ids"].to(device),
                attention_mask=llm_batch_encoding["attention_mask"].to(device),
            )
        )
        if self.llm_type == "qwen3":
            ctxt_raw = llm_outputs.hidden_states[-1]
        else:
            ctxt_raw = llm_outputs.last_hidden_state

        # Drop the template-prefix tokens; keep exactly _orig_max_length_llm.
        start = self.crop_start
        end = start + self._orig_max_length_llm
        ctxt_raw = ctxt_raw[:, start:end].contiguous()  # [bs, _orig_max_length_llm, hidden]
        ctxt_length = (llm_batch_encoding["attention_mask"].sum(dim=-1).to(device) - start).clamp(
            min=0, max=self._orig_max_length_llm
        )
        return ctxt_raw, ctxt_length

    @torch.no_grad()
    def encode_sentence_emb(self, text: List[str]) -> Tensor:
        """Encode prompts into pooled sentence embeddings of shape [B, 1, D].

        Pooling follows ``_sentence_emb_pooling_mode``: ``pooler_output``
        (with masked-mean fallback), ``mean``, or ``last_token``.

        Raises:
            ValueError: if the sentence-embedding branch was disabled, or the
                pooling mode is unknown.
        """
        if (
            self.sentence_emb_type is None
            or self.sentence_emb_text_encoder is None
            or self.sentence_emb_tokenizer is None
        ):
            raise ValueError("Sentence embedding model not initialized")

        device = get_module_device(self)
        enc = self.sentence_emb_tokenizer(
            text,
            return_length=False,
            return_overflowing_tokens=False,
            truncation=True,
            return_attention_mask=True,
            max_length=self.max_length_sentence_emb,
            padding=True,
            return_tensors="pt",
        )
        out = self.sentence_emb_text_encoder(
            input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device)
        )
        if self._sentence_emb_pooling_mode == "pooler_output":
            # Pooler-output pooling (e.g. clip-vit-large-patch14); falls back
            # to masked mean pooling when the model exposes no pooler output.
            if hasattr(out, "pooler_output") and out.pooler_output is not None:
                vtxt_raw = out.pooler_output.unsqueeze(1)
            else:
                vtxt_raw = self._encode_pooling(enc["attention_mask"].to(device), out.last_hidden_state)
        elif self._sentence_emb_pooling_mode == "mean":
            vtxt_raw = self._encode_pooling(enc["attention_mask"].to(device), out.last_hidden_state)
        elif self._sentence_emb_pooling_mode == "last_token":
            vtxt_raw = self._last_token_pool(out.last_hidden_state, enc["attention_mask"].to(device))
        else:
            raise ValueError(f"Unknown pooling mode: {self._sentence_emb_pooling_mode}")

        return vtxt_raw

    def encode(self, text: List[str]) -> Tuple[Tensor, Tensor, Tensor]:
        """Run both branches; returns ``(vtxt_raw, ctxt_raw, ctxt_length)``."""
        ctxt_raw, ctxt_length = self.encode_llm(text=text)
        vtxt_raw = self.encode_sentence_emb(text=text)
        return vtxt_raw, ctxt_raw, ctxt_length

    @staticmethod
    def apply_text_to_template(text: str, template: Union[str, list]) -> Union[str, list]:
        """Insert ``text`` into a template.

        A ``str`` template is treated as a format string; a ``list`` template
        is treated as chat messages whose first entry provides the system
        prompt, with ``text`` becoming the user message.
        """
        if isinstance(template, str):
            return template.format(text)
        elif isinstance(template, list):
            return [
                {"role": "system", "content": f"{template[0]['content']}"},
                {"role": "user", "content": f"{text}"},
            ]
        else:
            raise TypeError(f"Unsupported template type: {type(template)}")

    def _compute_crop_start(self) -> int:
        """Locate how many tokens the (chat) template prepends to the prompt.

        Renders the template around a sentinel marker, tokenizes the result,
        and returns the token index where the marker begins; falls back to
        ``len(full_ids) - 1`` if the marker tokenizes non-contiguously.
        """
        if self.llm_type is None or self.llm_text_encoder is None or self.llm_tokenizer is None:
            raise ValueError("LLM model not initialized")

        def _find_subseq(a: str, b: str) -> int:
            # First index of token subsequence b inside a, or -1.
            for i in range(0, len(a) - len(b) + 1):
                if a[i : i + len(b)] == b:
                    return i
            return -1

        marker = "<BOC>"
        if self.llm_type == "qwen3":
            msgs = self.apply_text_to_template(marker, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
            s = self.llm_tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=False, enable_thinking=False
            )
        else:
            s = self.apply_text_to_template(marker, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
        full_ids = self.llm_tokenizer(s, return_tensors="pt", add_special_tokens=True)["input_ids"][0].tolist()
        marker_ids = self.llm_tokenizer(marker, return_tensors="pt", add_special_tokens=False)["input_ids"][0].tolist()
        pos = _find_subseq(full_ids, marker_ids)
        if pos >= 0:
            return pos
        else:
            return max(0, len(full_ids) - 1)

    def _pad_or_truncate_tensor(self, tensor: Tensor, target_length: int, dim: int = 0) -> Tensor:
        """Force ``tensor`` to ``target_length`` along ``dim``.

        Truncates when too long; pads by replicating the final slice when too
        short (the zeros + last-slice broadcast below is replicate padding).
        """
        current_length = tensor.shape[dim]
        if current_length > target_length:
            return tensor.narrow(dim, 0, target_length)
        elif current_length < target_length:
            pad_shape = list(tensor.shape)
            pad_shape[dim] = target_length - current_length
            padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor.narrow(dim, -1, 1)
            return torch.cat([tensor, padding], dim=dim)
        return tensor

    def _encode_pooling(self, attention_mask: Tensor, token_embeddings: Tensor) -> Tensor:
        """Masked mean pooling + L2 normalization; returns [bs, 1, D]."""
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sentence_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
        vtxt_raw = nn.functional.normalize(sentence_embeddings, p=2, dim=1).unsqueeze(1)  # shape of [bs, 1, D]
        return vtxt_raw

    def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Pool the last valid token's hidden state; returns [bs, 1, D].

        Handles both left padding (last column all ones) and right padding
        (index each row at its own sequence length).
        """
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            vtxt_raw = last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            vtxt_raw = last_hidden_states[
                torch.arange(batch_size, device=last_hidden_states.device),
                sequence_lengths,
            ]
        vtxt_raw = nn.functional.normalize(vtxt_raw, p=2, dim=-1).unsqueeze(1)  # shape of [bs, 1, D]
        return vtxt_raw
278
+
279
+
280
if __name__ == "__main__":
    # Smoke test: python -m hymotion.network.text_encoders.text_encoder
    # NOTE: loads the pretrained Qwen3 + CLIP weights, so this needs the
    # model files (and possibly network access) to run.
    text_encoder = HYTextModel(llm_type="qwen3", max_length_llm=5)
    vtxt_raw, ctxt_raw, ctxt_length = text_encoder.encode(["Hello, world!"])
    print(vtxt_raw.shape, ctxt_raw.shape, ctxt_length)

    crop_start = text_encoder._compute_crop_start()
    print(f"crop_start: {crop_start} when using {text_encoder.llm_type}")

    # Sanity-check the documented output contract of `encode`.
    assert (
        vtxt_raw.shape[1:] == (1, text_encoder.vtxt_dim)
        and ctxt_raw.shape[1:] == (text_encoder._orig_max_length_llm, text_encoder.ctxt_dim)
        and torch.all((ctxt_length >= 0) & (ctxt_length <= text_encoder._orig_max_length_llm))
    ), f"Got unexpected output shape: {vtxt_raw.shape}, {ctxt_raw.shape}, {ctxt_length}"
hymotion/network/token_refiner.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from torch import Tensor
7
+
8
+ from .attention import attention
9
+ from .bricks import get_norm_layer
10
+ from .encoders import MLP, MLPEncoder, TimestepEmbeddingEncoder
11
+ from .modulate_layers import ModulateDiT, apply_gate
12
+
13
+
14
class IndividualTokenRefinerBlock(nn.Module):
    """Pre-norm self-attention block with adaLN output gating.

    The conditioning vector ``c`` yields two gates (one for the attention
    branch, one for the MLP branch); each branch output is gated before its
    residual add. Q/K are normalized per head before attention.
    """

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        mlp_act_type: str = "silu",
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
    ) -> None:
        super().__init__()
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        assert self.feat_dim % num_heads == 0, f"feat_dim {self.feat_dim} must be divisible by num_heads {num_heads}"
        self.head_dim = feat_dim // num_heads
        self.mlp_hidden_dim = int(feat_dim * mlp_ratio)

        layer_norm_cls = get_norm_layer(norm_type="layer")
        qk_norm_cls = get_norm_layer(qk_norm_type)

        # Attention branch.
        self.norm1 = layer_norm_cls(self.feat_dim, elementwise_affine=True, eps=1e-6)
        self.self_attn_qkv = nn.Linear(feat_dim, feat_dim * 3, bias=qkv_bias)
        self.self_attn_q_norm = qk_norm_cls(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.self_attn_k_norm = qk_norm_cls(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.self_attn_proj = nn.Linear(feat_dim, feat_dim, bias=qkv_bias)

        # MLP branch.
        self.norm2 = layer_norm_cls(self.feat_dim, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(
            in_dim=feat_dim,
            feat_dim=self.mlp_hidden_dim,
            act_type=mlp_act_type,
            drop=dropout,
        )

        # Conditioning -> (attention gate, MLP gate).
        self.adaLN_modulation = ModulateDiT(
            feat_dim=feat_dim,
            factor=2,
            act_type="silu",
        )

    def forward(self, x: Tensor, c: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
        """Refine tokens ``x`` under conditioning ``c``; ``attn_mask`` is an
        optional additive attention mask."""
        gate_attn, gate_ffn = self.adaLN_modulation(c).chunk(2, dim=-1)

        # Self-attention sub-layer (pre-norm, per-head QK normalization).
        qkv = self.self_attn_qkv(self.norm1(x))
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)
        attn_out = attention(q, k, v, mode="torch", attn_mask=attn_mask)
        x = x + apply_gate(self.self_attn_proj(attn_out), gate_attn)

        # Feed-forward sub-layer.
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_ffn)
        return x
70
+
71
+
72
class IndividualTokenRefiner(nn.Module):
    """Stack of ``IndividualTokenRefinerBlock``s sharing one self-attention
    mask built from a per-token validity mask."""

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        mlp_act_type: str = "silu",
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    feat_dim=feat_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    mlp_act_type=mlp_act_type,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: Tensor, c: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Run all blocks on ``x``.

        Args:
            x: (B, L, feat_dim) tokens.
            c: conditioning passed to every block's adaLN modulation.
            mask: optional (B, L) validity mask; 1/True marks real tokens.

        The boolean pairwise mask is converted to an additive float mask
        (0 for visible, -inf for hidden) before being handed to the blocks.
        """
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
            # A pair (q, k) attends only if BOTH tokens are valid.
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # avoids self-attention weight being NaN for padding tokens
            # assume the shape of self_attn_mask is [B, H, Q, K] and this is self-attention (Q==K==L)
            L = self_attn_mask.size(-1)
            diag = torch.eye(L, dtype=torch.bool, device=self_attn_mask.device).view(1, 1, L, L)  # [1,1,L,L]
            # mark which query row is "all False" (no visible key)
            all_false = ~self_attn_mask.any(dim=-1, keepdim=False)  # [B, H, Q]
            # expand to [B, H, Q, K]; only for these rows, make the diagonal visible
            # (a padding token attends to itself so softmax stays finite)
            all_false = all_false.unsqueeze(-1).expand(-1, -1, -1, L)
            self_attn_mask = torch.where(all_false, diag.expand_as(self_attn_mask), self_attn_mask)

        if self_attn_mask is not None:
            # Boolean -> additive mask: 0 where visible, -inf where hidden.
            self_attn_mask = torch.where(
                self_attn_mask,
                torch.zeros_like(self_attn_mask, dtype=torch.float),
                torch.full_like(self_attn_mask, float("-inf"), dtype=torch.float),
            )
        for block in self.blocks:
            x = block(x, c, self_attn_mask)
        return x
131
+
132
+
133
class SingleTokenRefiner(nn.Module):
    """Projects raw text tokens to ``feat_dim`` and refines them with a stack
    of adaLN self-attention blocks, conditioned on the diffusion timestep plus
    a (mask-aware) mean pool of the input tokens.
    """

    def __init__(
        self,
        input_dim: int,
        feat_dim: int,
        num_heads: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        mlp_act_type: str = "silu",
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        attn_mode: str = "torch",
        **kwargs,
    ) -> None:
        super().__init__()
        self.attn_mode = attn_mode
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        self.input_embedder = nn.Linear(input_dim, feat_dim, bias=True)
        # NOTE(review): the pooled context fed to this encoder is in the raw
        # input space, so in_dim=feat_dim implies input_dim == feat_dim in
        # practice — confirm against the configs that instantiate this class.
        self.context_encoder = MLPEncoder(
            in_dim=feat_dim,
            feat_dim=feat_dim,
            num_layers=2,
            act_type=mlp_act_type,
        )
        self.timestep_encoder = TimestepEmbeddingEncoder(
            embedding_dim=feat_dim,
            feat_dim=feat_dim,
            act_type=mlp_act_type,
        )

        self.individual_token_refiner = IndividualTokenRefiner(
            feat_dim=feat_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            mlp_act_type=mlp_act_type,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
        )

    def forward(self, x: Tensor, t: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Return refined tokens of shape (B, L, feat_dim).

        Args:
            x: (B, L, input_dim) raw token features.
            t: diffusion timestep(s) for the timestep encoder.
            mask: optional (B, L) validity mask; also forwarded to the blocks.
        """
        t_emb = self.timestep_encoder(t)

        # Mean-pool the input tokens, ignoring padding when a mask is given.
        if mask is None:
            pooled = x.mean(dim=1)
        else:
            weights = mask.float().unsqueeze(-1)
            pooled = (x * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1e-6)
        pooled = self.context_encoder(pooled).unsqueeze(1)

        cond = t_emb + pooled
        tokens = self.input_embedder(x)
        return self.individual_token_refiner(tokens, cond, mask)
hymotion/pipeline/body_model.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+
10
+ from ..utils.geometry import (
11
+ rot6d_to_rotation_matrix,
12
+ rotation_matrix_to_angle_axis,
13
+ )
14
+
15
# yapf: disable
# Mean hand poses in axis-angle form: 15 hand joints x 3 values per hand.
# The right hand mirrors the left (x components equal, y/z sign-flipped).
# NOTE(review): presumably the MANO/SMPL-X mean hand pose — confirm against
# the body-model release these values were taken from.
LEFT_HAND_MEAN_AA = [ 0.1117, 0.0429, -0.4164, 0.1088, -0.0660, -0.7562, -0.0964, -0.0909,
                      -0.1885, -0.1181, 0.0509, -0.5296, -0.1437, 0.0552, -0.7049, -0.0192,
                      -0.0923, -0.3379, -0.4570, -0.1963, -0.6255, -0.2147, -0.0660, -0.5069,
                      -0.3697, -0.0603, -0.0795, -0.1419, -0.0859, -0.6355, -0.3033, -0.0579,
                      -0.6314, -0.1761, -0.1321, -0.3734, 0.8510, 0.2769, -0.0915, -0.4998,
                      0.0266, 0.0529, 0.5356, 0.0460, -0.2774]
RIGHT_HAND_MEAN_AA = [ 0.1117, -0.0429, 0.4164, 0.1088, 0.0660, 0.7562, -0.0964, 0.0909,
                       0.1885, -0.1181, -0.0509, 0.5296, -0.1437, -0.0552, 0.7049, -0.0192,
                       0.0923, 0.3379, -0.4570, 0.1963, 0.6255, -0.2147, 0.0660, 0.5069,
                       -0.3697, 0.0603, 0.0795, -0.1419, 0.0859, 0.6355, -0.3033, 0.0579,
                       0.6314, -0.1761, 0.1321, 0.3734, 0.8510, -0.2769, 0.0915, -0.4998,
                       -0.0266, -0.0529, 0.5356, -0.0460, 0.2774]
# yapf: enable
29
+
30
+
31
def to_tensor(array, dtype=torch.float32, device=torch.device("cpu")):
    """Convert array-like data to a tensor on ``device``.

    Existing tensors are moved to ``device`` unchanged (dtype preserved, no
    copy when already there); anything else is wrapped with ``torch.tensor``
    using ``dtype``.

    Bug fixed: the original check ``"torch.tensor" not in str(type(array))``
    never matched ``"<class 'torch.Tensor'>"`` (case mismatch), so tensors
    were always re-wrapped by ``torch.tensor`` — a warning-emitting copy.
    """
    if isinstance(array, torch.Tensor):
        return array.to(device)
    return torch.tensor(array, dtype=dtype).to(device)
36
+
37
+
38
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
    """Calculates the rotation matrices for a batch of rotation vectors.

    Parameters
    ----------
    rot_vecs: torch.tensor Nx3 (or ...x3; leading dims are restored)
        array of N axis-angle vectors
    epsilon: float
        stabilizer added to the vector before taking its norm, so the
        zero rotation maps to (approximately) the identity without a
        division by zero
    dtype: torch.dtype
        dtype used for the intermediate skew/identity tensors

    Returns
    -------
    R: torch.tensor Nx3x3 (or ...x3x3)
        The rotation matrices for the given axis-angle parameters

    Fixes over the original: the ``epsilon`` parameter was ignored (a literal
    1e-8 was used), and a dead first assignment of the skew matrix removed.
    """
    # Flatten any leading batch dims down to (N, 3); restore them at the end.
    if len(rot_vecs.shape) > 2:
        rot_vec_ori = rot_vecs
        rot_vecs = rot_vecs.view(-1, 3)
    else:
        rot_vec_ori = None
    batch_size = rot_vecs.shape[0]
    device = rot_vecs.device

    # Perturb by epsilon so zero vectors do not divide by zero below.
    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
    rot_dir = rot_vecs / angle

    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Bx1 arrays
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    # Skew-symmetric cross-product matrix K of the rotation axis.
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))

    # Rodrigues formula: R = I + sin(theta) K + (1 - cos(theta)) K^2.
    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    if rot_vec_ori is not None:
        rot_mat = rot_mat.reshape(*rot_vec_ori.shape[:-1], 3, 3)
    return rot_mat
75
+
76
+
77
def load_model_data(model_path):
    """
    Load wooden model data from binary files.

    Vertex and joint counts are inferred from the file sizes (each record is
    3 float32 values), so no header/metadata file is required.
    NOTE: ``np.frombuffer`` returns read-only arrays; callers that need to
    mutate the data should copy it (``to_tensor``/``torch.tensor`` copies).

    Args:
        model_path: path to the directory containing .bin files

    Returns:
        dict containing:
            - v_template: (V, 3) vertex template
            - j_template: (J, 3) joint template
            - skin_weights: (V, 4) skin weights
            - skin_indices: (V, 4) skin indices
            - parents: (J,) parent indices (kintree)
            - faces: (F, 3) face indices
            - joint_names: list of joint names
    """
    model_path = Path(model_path)

    # Load vertex template: (V*3,) -> (V, 3)
    with open(model_path / "v_template.bin", "rb") as f:
        v_template_flat = np.frombuffer(f.read(), dtype=np.float32)
    num_verts = len(v_template_flat) // 3
    v_template = v_template_flat.reshape(num_verts, 3)

    # Load joint template: (J*3,) -> (J, 3)
    with open(model_path / "j_template.bin", "rb") as f:
        j_template_flat = np.frombuffer(f.read(), dtype=np.float32)
    num_joints = len(j_template_flat) // 3
    j_template = j_template_flat.reshape(num_joints, 3)

    # Load skin weights: (V*4,) -> (V, 4), 4 bones per vertex
    with open(model_path / "skinWeights.bin", "rb") as f:
        skin_weights_flat = np.frombuffer(f.read(), dtype=np.float32)
    skin_weights = skin_weights_flat.reshape(num_verts, 4)

    # Load skin indices: (V*4,) -> (V, 4), 4 bone indices per vertex
    with open(model_path / "skinIndice.bin", "rb") as f:
        skin_indices_flat = np.frombuffer(f.read(), dtype=np.uint16)
    skin_indices = skin_indices_flat.reshape(num_verts, 4).astype(np.int64)

    # Load kintree (parent indices): (J,)
    with open(model_path / "kintree.bin", "rb") as f:
        parents = np.frombuffer(f.read(), dtype=np.int32)

    # Load faces
    with open(model_path / "faces.bin", "rb") as f:
        faces_flat = np.frombuffer(f.read(), dtype=np.uint16)
    faces = faces_flat.reshape(-1, 3)

    # Load joint names; fall back to generated placeholders when absent.
    joint_names_path = model_path / "joint_names.json"
    if joint_names_path.exists():
        with open(joint_names_path, "r") as f:
            joint_names = json.load(f)
    else:
        joint_names = [f"Joint_{i}" for i in range(num_joints)]

    return {
        "v_template": v_template,
        "j_template": j_template,
        "skin_weights": skin_weights,
        "skin_indices": skin_indices,
        "parents": parents,
        "faces": faces,
        "joint_names": joint_names,
        "num_joints": num_joints,
        "num_verts": num_verts,
    }
146
+
147
+
148
def simple_lbs(v_template, rot_mats, joints, parents, skin_weights, skin_indices):
    """
    Simple Linear Blend Skinning without shape blending.

    Pipeline: build per-joint local 4x4 transforms, accumulate them along the
    kinematic chain (forward kinematics), convert to rest-pose-relative
    transforms, then blend 4 bone transforms per vertex.

    Args:
        v_template: (V, 3) template vertices
        rot_mats: (B, J, 3, 3) rotation matrices for each joint
        joints: (J, 3) joint positions in rest pose
        parents: (J,) parent indices for each joint (parents[0] is the root)
        skin_weights: (V, 4) skin weights for 4 bones per vertex
        skin_indices: (V, 4) bone indices for 4 bones per vertex

    Returns:
        vertices: (B, V, 3) transformed vertices
        posed_joints: (B, J, 3) transformed joint positions
    """
    batch_size = rot_mats.shape[0]
    num_joints = rot_mats.shape[1]
    num_verts = v_template.shape[0]
    device = rot_mats.device
    dtype = rot_mats.dtype

    # Compute relative joint positions (offset from parent joint).
    # clone() protects the caller's `joints` buffer from the in-place write.
    rel_joints = joints.clone()
    rel_joints[1:] = joints[1:] - joints[parents[1:]]

    # Build transformation chain: transforms_mat (B, J, 4, 4)
    transforms_mat = torch.zeros(batch_size, num_joints, 4, 4, device=device, dtype=dtype)
    transforms_mat[..., :3, :3] = rot_mats
    transforms_mat[..., :3, 3] = rel_joints.unsqueeze(0).expand(batch_size, -1, -1)
    transforms_mat[..., 3, 3] = 1.0

    # Forward kinematics: accumulate transforms from root to each joint.
    # Assumes parents[] is topologically ordered (parent index < child index).
    transform_chain = [transforms_mat[:, 0]]
    for i in range(1, num_joints):
        parent_idx = parents[i].item()
        curr_transform = torch.bmm(transform_chain[parent_idx], transforms_mat[:, i])
        transform_chain.append(curr_transform)

    transforms = torch.stack(transform_chain, dim=1)  # (B, J, 4, 4)

    # Get posed joint positions
    posed_joints = transforms[..., :3, 3].clone()  # (B, J, 3)

    # Compute relative transforms (for skinning)
    # We need to subtract the rest pose joint positions from the transform
    # so each transform maps rest-pose space to posed space.
    rel_transforms = transforms.clone()
    joints_homo = F.pad(joints, [0, 1], value=0)  # (J, 4)
    transformed_rest = torch.einsum("bjcd,jd->bjc", transforms[..., :3, :], joints_homo)
    rel_transforms[..., :3, 3] = transforms[..., :3, 3] - transformed_rest[..., :3]

    # Apply skinning: gather transforms for each vertex's 4 bones
    # skin_indices: (V, 4), skin_weights: (V, 4)
    vertex_transforms = torch.zeros(batch_size, num_verts, 4, 4, 4, device=device, dtype=dtype)
    for k in range(4):
        bone_idx = skin_indices[:, k].long()  # (V,)
        vertex_transforms[:, :, k] = rel_transforms[:, bone_idx]  # (B, V, 4, 4)

    # Weight the transforms (convex blend of the 4 bone transforms).
    skin_weights_expanded = skin_weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # (1, V, 4, 1, 1)
    skin_weights_expanded = skin_weights_expanded.expand(batch_size, -1, -1, 4, 4)  # (B, V, 4, 4, 4)

    weighted_transforms = (vertex_transforms * skin_weights_expanded).sum(dim=2)  # (B, V, 4, 4)

    # Apply to vertices (homogeneous coordinates, w=1).
    v_homo = F.pad(v_template, [0, 1], value=1.0)  # (V, 4)
    vertices = torch.einsum("bvcd,vd->bvc", weighted_transforms[..., :3, :], v_homo)  # (B, V, 3)

    return vertices, posed_joints
217
+
218
+
219
class WoodenMesh(torch.nn.Module):
    """
    Wooden character mesh model that loads from binary files.
    Uses simple LBS without shape blending (fixed skeleton).

    All geometry/skinning data is registered as buffers so the model follows
    ``.to(device)`` and is saved in the state dict; there are no trainable
    parameters.
    """

    def __init__(self, model_path="scripts/gradio/static/assets/dump_wooden"):
        torch.nn.Module.__init__(self)

        # Load model data from .bin files
        model = load_model_data(model_path)

        # Register buffers like original SMPLMesh
        v_template = to_tensor(model["v_template"])
        self.register_buffer("v_template", v_template)

        j_template = to_tensor(model["j_template"])
        self.register_buffer("j_template", j_template)

        skin_weights = to_tensor(model["skin_weights"])
        self.register_buffer("skin_weights", skin_weights)

        skin_indices = to_tensor(model["skin_indices"], dtype=torch.long)
        self.register_buffer("skin_indices", skin_indices)

        parents = to_tensor(model["parents"], dtype=torch.long)
        self.register_buffer("parents", parents)

        # Store non-buffer attributes
        self.faces = model["faces"]
        self.joint_names = model["joint_names"]
        self.num_joints = model["num_joints"]
        self.num_verts = model["num_verts"]

        print(f"[WoodenMesh] Loaded model: {self.num_verts} vertices, {self.num_joints} joints")

    def forward(self, params, fast_forward=False):
        """
        Forward pass to compute deformed vertices.

        Args:
            params: dict containing:
                - 'poses': (B, J*3) axis-angle rotations, or
                - 'rot6d': (B, J, 6) 6D rotation representations
                - 'trans': (B, 3) optional translation
            fast_forward: unused here; kept for interface compatibility
                (presumably with the SMPL mesh model — confirm).

        Returns:
            dict with 'vertices', 'vertices_wotrans' and 'keypoints3d'
        """
        if "poses" in params:
            poses = params["poses"]
            batch_size = poses.shape[0]
            rot_mats = batch_rodrigues(poses.view(-1, 3)).view([batch_size, -1, 3, 3])
        elif "rot6d" in params:
            rot6d = params["rot6d"]
            batch_size = rot6d.shape[0]
            rot_mats = rot6d_to_rotation_matrix(rot6d).view([batch_size, -1, 3, 3])
        else:
            raise ValueError("poses or rot6d must be in params")

        # Body-only input (22 joints): pad the remaining 30 joints
        # (presumably the hands) with identity rotations.
        if rot_mats.shape[1] == 22:
            eye = torch.eye(3, device=rot_mats.device, dtype=rot_mats.dtype)[None, None, :, :].repeat(
                batch_size, 30, 1, 1
            )
            rot_mats = torch.cat([rot_mats, eye], dim=1)  # (B, 22 + 30, 3, 3)

        # Simple LBS (no shape blending, fixed skeleton)
        vertices, posed_joints = simple_lbs(
            self.v_template,
            rot_mats,
            self.j_template,
            self.parents,
            self.skin_weights,
            self.skin_indices,
        )

        # Vertices without translation (for pose-level supervision)
        vertices_wotrans = vertices

        if "trans" in params:
            trans = params["trans"]
            vertices = vertices + trans[:, None, :]

        return {
            "vertices": vertices,
            "vertices_wotrans": vertices_wotrans,
            "keypoints3d": posed_joints,
        }

    def forward_batch(self, params):
        """Run ``forward`` on a (B, T, ...) sequence by flattening the batch
        and time dims, then restoring (B, T, ...) on every output."""
        assert "rot6d" in params and "trans" in params
        rot6d = params["rot6d"]
        trans = params["trans"]
        bs, num_frames = rot6d.shape[:2]
        rot6d_flat = rot6d.reshape(bs * num_frames, rot6d.shape[2], rot6d.shape[3])
        trans_flat = trans.reshape(bs * num_frames, trans.shape[2])
        result = self.forward(
            {
                "rot6d": rot6d_flat,
                "trans": trans_flat,
            }
        )
        out = {}
        for key in result:
            out[key] = result[key].reshape(bs, num_frames, *result[key].shape[1:])
        return out
325
+
326
+
327
def construct_smpl_data_dict(
    rot6d: Tensor,
    transl: Tensor,
    betas: Optional[Tensor] = None,
    gender: str = "neutral",
    use_default_hand_mean_pose: bool = False,
) -> dict:
    """Pack a rot6d pose sequence into an SMPL-style numpy dict.

    Args:
        rot6d: (F, J, 6) per-frame 6D joint rotations; J must be 22 (body
            only) or 52 (body + both hands).
        transl: (F, 3) root translations.
        betas: optional shape coefficients; zeros of shape (1, 16) when None.
        gender: body-model gender tag stored in the dict.
        use_default_hand_mean_pose: for J == 52 inputs, replace the predicted
            hand joints with the canonical mean hand pose.

    Returns:
        dict with numpy ``poses`` (F, 156), ``trans``, ``betas``, ``Rh``
        (global orientation = first 3 pose values), ``gender``,
        ``num_frames`` and ``mocap_framerate`` (fixed at 30).

    Cleanups vs. the original: the dead ``else: angle_axis = angle_axis``
    branch is gone and the duplicated mean-hand-pose construction is shared.
    """
    rotation_matrix = rot6d_to_rotation_matrix(rot6d)
    angle_axis = rotation_matrix_to_angle_axis(rotation_matrix)
    num_frames = angle_axis.shape[0]
    num_joints = angle_axis.shape[1]

    def _mean_hand_pose(flat_values) -> Tensor:
        # (F, 15, 3): the constant mean hand pose repeated for every frame.
        return (
            torch.tensor(flat_values, device=angle_axis.device, dtype=angle_axis.dtype)
            .unsqueeze(0)
            .repeat(num_frames, 1)
            .reshape(num_frames, -1, 3)
        )

    # Body-only input always gets mean hands appended; a full 52-joint input
    # keeps its predicted hands unless the caller asks for the mean pose.
    if num_joints == 22 or (num_joints == 52 and use_default_hand_mean_pose):
        angle_axis = torch.cat(
            [
                angle_axis[:, :22],
                _mean_hand_pose(LEFT_HAND_MEAN_AA),
                _mean_hand_pose(RIGHT_HAND_MEAN_AA),
            ],
            dim=1,
        )

    assert angle_axis.shape[1] == 52, f"angle_axis should be 52, but got {angle_axis.shape[1]}"
    poses_flat = angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1)
    return {
        "betas": betas.cpu().numpy() if betas is not None else np.zeros((1, 16)),
        "gender": gender,
        "poses": poses_flat,
        "trans": transl.cpu().numpy(),
        "mocap_framerate": 30,
        "num_frames": angle_axis.shape[0],
        # Independent copy so mutating Rh cannot corrupt `poses`.
        "Rh": poses_flat[:, :3].copy(),
    }
389
+
390
+
391
if __name__ == "__main__":
    # Smoke test: python -m hymotion.pipeline.body_model
    # Requires the dumped wooden-model .bin assets to be present.
    model_path = "scripts/gradio/static/assets/dump_wooden"
    model = WoodenMesh(model_path)
    # Single-frame call with the full 52-joint rig.
    params = {
        "rot6d": torch.randn(1, 52, 6),
        "trans": torch.randn(1, 3),
    }
    result = model(params)
    print(result.keys())
    print(result["vertices"].shape)
    print(result["vertices_wotrans"].shape)
    print(result["keypoints3d"].shape)
    # Batched sequence call with body-only (22-joint) poses, exercising the
    # identity padding of the 30 remaining joints and the (B, T) reshaping.
    params_batch = {
        "rot6d": torch.randn(3, 100, 22, 6),
        "trans": torch.randn(3, 100, 3),
    }
    result_batch = model.forward_batch(params_batch)
    print(result_batch.keys())
    print(result_batch["vertices"].shape)
    print(result_batch["vertices_wotrans"].shape)
    print(result_batch["keypoints3d"].shape)
hymotion/pipeline/motion_diffusion.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ from copy import deepcopy
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from scipy.signal import savgol_filter
9
+ from torch import Tensor
10
+ from torchdiffeq import odeint
11
+
12
+ from ..utils.geometry import (
13
+ matrix_to_quaternion,
14
+ quaternion_fix_continuity,
15
+ quaternion_to_matrix,
16
+ rot6d_to_rotation_matrix,
17
+ rotation_matrix_to_rot6d,
18
+ )
19
+ from ..utils.loaders import load_object
20
+ from ..utils.motion_process import smooth_rotation
21
+ from ..utils.type_converter import get_module_device
22
+ from .body_model import WoodenMesh
23
+
24
+
25
def length_to_mask(lengths: Tensor, max_len: int) -> Tensor:
    """Build a boolean padding mask from per-sample lengths.

    Args:
        lengths: (B,) or (B, 1) tensor of valid lengths per sample.
        max_len: number of columns in the output mask.

    Returns:
        (B, max_len) boolean tensor; True for positions < length.
    """
    # Fix: the original message hard-coded a garbage value ("max_len=72,166")
    # instead of interpolating the actual max_len argument.
    assert lengths.max() <= max_len, f"lengths.max()={lengths.max()} > max_len={max_len}"
    if lengths.ndim == 1:
        # Broadcast against the (B, max_len) index grid below.
        lengths = lengths.unsqueeze(1)
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths
    return mask
36
+
37
+
38
def start_end_frame_to_mask(start_frame: Tensor, end_frame: Tensor, max_len: int) -> Tensor:
    """Build a boolean mask that is True on the inclusive interval [start, end].

    Args:
        start_frame: (B,) tensor of first valid frame indices (>= 0).
        end_frame: (B,) tensor of last valid frame indices (>= 0, inclusive).
        max_len: number of columns in the output mask.

    Returns:
        (B, max_len) boolean tensor; True where start <= index <= end.
    """
    assert (start_frame >= 0).all() and (end_frame >= 0).all(), f"start_frame={start_frame}, end_frame={end_frame}"
    lengths = end_frame - start_frame + 1
    # Fix: the original message hard-coded a garbage value ("max_len=72,166")
    # instead of interpolating the actual max_len argument.
    assert lengths.max() <= max_len, f"lengths.max()={lengths.max()} > max_len={max_len}"
    if lengths.ndim == 1:
        lengths = lengths.unsqueeze(1)
    batch_size = start_frame.shape[0]
    arange_ids = torch.arange(max_len, device=start_frame.device).unsqueeze(0).expand(batch_size, max_len)
    # Inclusive on both ends by design (<= end_frame).
    mask = (arange_ids >= start_frame.unsqueeze(1)) & (arange_ids <= end_frame.unsqueeze(1))
    return mask
48
+
49
+
50
def randn_tensor(
    shape,
    generator=None,
    device=None,
    dtype=None,
    layout=None,
):
    """Draw standard-normal samples of `shape` on `device` with `dtype`.

    A list of generators seeds each batch element individually; if CPU
    generators are passed while a non-CPU device is requested, sampling
    happens on the CPU first and the result is moved to `device`.
    """
    sample_device = device  # device on which torch.randn actually runs
    num_samples = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        if isinstance(generator, list):
            gen_device_type = generator[0].device.type
        else:
            gen_device_type = generator.device.type
        device_mismatch = gen_device_type != device.type
        if device_mismatch and gen_device_type == "cpu":
            # Fall back to CPU sampling; move to the target device afterwards.
            sample_device = "cpu"
            if device != "mps":
                print(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif device_mismatch and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # A single-element generator list behaves exactly like a lone generator.
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One generator per batch element: sample slices of size 1 and stack.
        per_item_shape = (1,) + shape[1:]
        chunks = []
        for idx in range(num_samples):
            chunks.append(
                torch.randn(per_item_shape, generator=generator[idx], device=sample_device, dtype=dtype, layout=layout)
            )
        return torch.cat(chunks, dim=0).to(device)

    return torch.randn(shape, generator=generator, device=sample_device, dtype=dtype, layout=layout).to(device)
97
+
98
+
99
class MotionGeneration(torch.nn.Module):
    """Base module for text-conditioned motion generation.

    Holds the motion transformer, learned null/special conditioning tokens,
    normalization statistics (mean/std buffers), and decoding utilities that
    turn a denormalized latent into rot6d poses, translations, and (via the
    wooden body model) vertices/keypoints.
    """

    def __init__(
        self,
        network_module: str,
        network_module_args: dict,
        text_encoder_module: str,
        text_encoder_cfg: dict,
        mean_std_dir: str,
        motion_type="auto",
        **kwargs,
    ):
        """
        Args:
            network_module: dotted/path spec of the motion transformer, resolved by `load_object`.
            network_module_args: kwargs for the transformer; also consulted for
                `vtxt_input_dim` / `ctxt_input_dim` when sizing the null tokens.
            text_encoder_module: spec of the text encoder (built lazily).
            text_encoder_cfg: kwargs for the text encoder.
            mean_std_dir: directory holding Mean.npy / Std.npy; may be missing.
            motion_type: "auto" resolves to "o6dp" in `_find_motion_type`.
            **kwargs: optional flags, see attribute assignments below.
        """
        super().__init__()
        # build models and parameters
        self._network_module_args = deepcopy(network_module_args)
        self.motion_transformer = load_object(network_module, network_module_args)
        self._text_encoder_module = text_encoder_module
        self._text_encoder_cfg = deepcopy(text_encoder_cfg)
        self.motion_type = motion_type

        # Learned "null" tokens used for classifier-free guidance / unconditional mode.
        self.null_vtxt_feat = torch.nn.Parameter(
            torch.randn(1, 1, self._network_module_args.get("vtxt_input_dim", 768))
        )
        self.null_ctxt_input = torch.nn.Parameter(
            torch.randn(1, 1, self._network_module_args.get("ctxt_input_dim", 4096))
        )
        # Learned tokens injected for specific data sources (see _maybe_inject_source_token).
        self.special_game_vtxt_feat = torch.nn.Parameter(
            torch.randn(1, 1, self._network_module_args.get("vtxt_input_dim", 768))
        )
        self.special_game_ctxt_feat = torch.nn.Parameter(
            torch.randn(1, 1, self._network_module_args.get("ctxt_input_dim", 4096))
        )
        # build buffer
        self.mean_std_dir = mean_std_dir
        self._parse_buffer(self.motion_type)

        # Output framerate of decoded motion (fps).
        self.output_mesh_fps = kwargs.get("output_mesh_fps", 30)
        # Temporal length (frames) the transformer was trained with.
        self.train_frames = kwargs.get("train_frames", 360)
        # If True, generation ignores text and uses only the null tokens.
        self.uncondition_mode = kwargs.get("uncondition_mode", False)
        # If True, CFG uses the learned null context instead of the real context.
        self.enable_ctxt_null_feat = kwargs.get("enable_ctxt_null_feat", False)
        self.enable_special_game_feat = kwargs.get("enable_special_game_feat", False)
        # If True, noise seeds use a CUDA generator (device-dependent streams).
        self.random_generator_on_gpu = kwargs.get("random_generator_on_gpu", True)

    def _parse_buffer(self, mode: str) -> None:
        """Build the body model, resolve the motion type, and load mean/std buffers."""
        self.body_model = WoodenMesh()
        self._find_motion_type(mode=mode)
        self._load_mean_std()

    def _load_mean_std(self, mean_std_name: Optional[str] = None) -> None:
        """Load Mean.npy/Std.npy from a directory into buffers, or fall back to 0/1.

        Args:
            mean_std_name: directory containing the stats; defaults to self.mean_std_dir.
        """
        mean_std_name = self.mean_std_dir if mean_std_name is None else mean_std_name
        if mean_std_name is not None and osp.isdir(mean_std_name):
            mean = torch.from_numpy(np.load(osp.join(mean_std_name, "Mean.npy"))).float()
            std = torch.from_numpy(np.load(osp.join(mean_std_name, "Std.npy"))).float()
            self._assert_motion_dimension(mean.unsqueeze(0), std.unsqueeze(0))
            self.register_buffer("mean", mean)
            self.register_buffer("std", std)
        else:
            # Identity normalization when no statistics are available.
            print(
                f"[{self.__class__.__name__}] No mean_std found, using blank mean_std, "
                f"self.mean_std_dir={self.mean_std_dir}"
            )
            self.register_buffer("mean", torch.zeros(1))
            self.register_buffer("std", torch.ones(1))

    def _assert_motion_dimension(self, mean: Tensor, std: Tensor) -> None:
        """Validate that the loaded statistics match the expected (1, 201) layout."""
        assert mean.shape == std.shape, f"mean.shape={mean.shape} != std.shape={std.shape}"
        assert mean.ndim == 2, f"mean.ndim={mean.ndim} != 2"
        # 201 = 3 (transl) + 6 (root rot6d) + 21*6 (body rot6d) + remaining features;
        # exact composition is defined by the dataset pipeline — see _decode_o6dp.
        assert mean.shape == (1, 201), f"mean.shape={mean.shape} != (1, 201)"

    def _find_motion_type(self, mode: str) -> None:
        """Resolve the motion representation type; 'auto' maps to 'o6dp'."""
        if mode == "auto":
            self.motion_type = "o6dp"
        else:
            self.motion_type = mode

    def set_epoch(self, epoch) -> None:
        """Record the current training epoch (used by external training loops)."""
        self.current_epoch = epoch

    def load_in_demo(
        self,
        ckpt_name: str,
        mean_std_name: Optional[str] = None,
        build_text_encoder: bool = True,
        allow_empty_ckpt: bool = False,
    ) -> None:
        """Load a checkpoint and (optionally) the text encoder for demo inference.

        Args:
            ckpt_name: checkpoint path; missing files emit a warning unless
                allow_empty_ckpt is True (then loading is skipped silently).
            mean_std_name: optional stats file/dir; non-file paths reset to None.
            build_text_encoder: build the text encoder unless in unconditional mode.
            allow_empty_ckpt: if True, skip checkpoint loading entirely.
        """
        if not allow_empty_ckpt:
            if not os.path.exists(ckpt_name):
                import warnings

                warnings.warn(f"Checkpoint {ckpt_name} not found, skipping model loading")
            else:
                checkpoint = torch.load(ckpt_name, map_location="cpu", weights_only=False)
                # strict=False tolerates missing/extra keys (e.g. text encoder weights).
                self.load_state_dict(checkpoint["model_state_dict"], strict=False)
        if mean_std_name is not None:
            assert os.path.exists(mean_std_name), f"{mean_std_name} not found"
            if not os.path.isfile(mean_std_name):
                # Directories fall back to the constructor-provided mean_std_dir.
                mean_std_name = None
        self._load_mean_std(mean_std_name)
        self.motion_transformer.eval()
        if build_text_encoder and not self.uncondition_mode:
            self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)
            self.text_encoder.to(get_module_device(self))

    @torch.no_grad()
    def encode_text(self, text: Dict[str, List[str]]) -> Dict[str, Tensor]:
        """Encode prompts with the (lazily built) text encoder.

        Args:
            text: dict with key "text" mapping to a list of prompt strings.

        Returns:
            Dict with the pooled text vector, the per-token context, and its lengths.
        """
        if not hasattr(self, "text_encoder"):
            self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)
            self.text_encoder.to(get_module_device(self))
        text = text["text"]
        vtxt_input, ctxt_input, ctxt_length = self.text_encoder.encode(text=text)
        return {
            "text_vec_raw": vtxt_input,
            "text_ctxt_raw": ctxt_input,
            "text_ctxt_raw_length": ctxt_length,
        }

    def decode_motion_from_latent(self, latent: Tensor, should_apply_smooothing: bool = True) -> Dict[str, Tensor]:
        """Denormalize a latent and decode it into poses/translations/keypoints.

        Note: `should_apply_smooothing` is misspelled but kept as-is for
        interface compatibility with existing callers.
        """
        # Dims with (near-)zero std carry no variance: zero them so the
        # denormalized value collapses to the stored mean for those dims.
        std_zero = self.std < 1e-3
        std = torch.where(std_zero, torch.zeros_like(self.std), self.std)
        latent_denorm = latent * std + self.mean
        return self._decode_o6dp(
            latent_denorm,
            num_joints=22,
            rel_trans=False,
            should_apply_smooothing=should_apply_smooothing,
        )

    def _decode_o6dp(
        self,
        latent_denorm: torch.Tensor,
        num_joints: int,
        rel_trans: bool = False,
        should_apply_smooothing: bool = True,
    ) -> dict:
        """Decode a denormalized o6dp latent into rot6d, translation, and keypoints.

        Layout per frame: [0:3] translation (or per-frame velocity if rel_trans),
        [3:9] root rot6d, [9:9+(num_joints-1)*6] body joint rot6d.

        Args:
            latent_denorm: (B, L, D) denormalized latent.
            num_joints: 22 (body only) or 52 (body + both hands).
            rel_trans: if True, translations are velocities and are integrated.
            should_apply_smooothing: apply slerp/savgol smoothing (misspelling
                kept for interface compatibility).
        """
        device = get_module_device(self)
        B, L = latent_denorm.shape[:2]
        nj = num_joints
        body_n = nj - 1  # joints excluding the root

        if not rel_trans:
            transl = latent_denorm[..., 0:3].clone()
        else:
            # Velocities -> positions: cumulative sum scaled by the frame period.
            transl = torch.cumsum(latent_denorm[..., 0:3].clone(), dim=1) / self.output_mesh_fps
        root_rot6d = latent_denorm[..., 3:9].reshape(B, L, 1, 6).clone()

        body6d_start = 9
        body6d_end = body6d_start + body_n * 6
        body_rot6d_full = latent_denorm[..., body6d_start:body6d_end].clone().reshape(B, L, body_n, 6)

        # 52 joints need to be split into hands
        left_hand_pose = right_hand_pose = None
        if nj == 52:
            body_rot6d = body_rot6d_full[:, :, :21, :].clone()
            left_hand_pose = body_rot6d_full[:, :, 21:36, :].clone()
            right_hand_pose = body_rot6d_full[:, :, 36:51, :].clone()
        else:
            body_rot6d = body_rot6d_full

        if left_hand_pose is not None and right_hand_pose is not None:
            body_full = torch.cat([body_rot6d, left_hand_pose, right_hand_pose], dim=2)
        else:
            body_full = body_rot6d
        rot6d = torch.cat([root_rot6d, body_full], dim=2)  # (B, L, nj, 6)
        if should_apply_smooothing:
            # only apply slerp smoothing to the first 22 joints (non-finger joints)
            rot6d_body = rot6d[:, :, :22, :]  # (B, L, 22, 6)
            rot6d_fingers = rot6d[:, :, 22:, :]  # (B, L, J-22, 6)
            rot6d_body_smooth = self.smooth_with_slerp(rot6d_body, sigma=1.0)
            rot6d_smooth = torch.cat([rot6d_body_smooth, rot6d_fingers], dim=2)
        else:
            rot6d_smooth = rot6d
        root_rotmat_smooth = rot6d_to_rotation_matrix(rot6d_smooth[:, :, 0, :])  # (B, L, 3, 3)

        transl_fixed = transl.detach()
        if should_apply_smooothing:
            # Savitzky-Golay smoothing over time for the global translation.
            transl_smooth = self.smooth_with_savgol(transl_fixed.detach(), window_length=11, polyorder=5)
        else:
            transl_smooth = transl_fixed

        if self.body_model is not None:
            print(
                f"{self.__class__.__name__} rot6d_smooth shape: {rot6d_smooth.shape}, transl_smooth shape: {transl_smooth.shape}"
            )
            with torch.no_grad():
                vertices_all = []
                k3d_all = []
                # The body model is called per batch element (it expects (L, J, 6)).
                for bs in range(rot6d_smooth.shape[0]):
                    out = self.body_model.forward({"rot6d": rot6d_smooth[bs], "trans": transl_smooth[bs]})
                    vertices_all.append(out["vertices"])
                    k3d_all.append(out["keypoints3d"])
                vertices = torch.stack(vertices_all, dim=0)
                k3d = torch.stack(k3d_all, dim=0)
                print(f"{self.__class__.__name__} vertices shape: {vertices.shape}, k3d shape: {k3d.shape}")
                # align with the ground
                min_y = vertices[..., 1].amin(dim=(1, 2), keepdim=True)  # (B, 1, 1)
                print(f"{self.__class__.__name__} min_y: {min_y}")
                k3d = k3d.clone()
                k3d[..., 1] -= min_y  # (B, L, J) - (B, 1, 1)
                transl_smooth = transl_smooth.clone()
                transl_smooth[..., 1] -= min_y.squeeze(-1).to(device)  # (B, L) - (B, 1)
        else:
            k3d = torch.zeros(B, L, nj, 3, device=device)

        return dict(
            latent_denorm=latent_denorm,  # (B, L, 201)
            keypoints3d=k3d,  # (B, L, J, 3)
            rot6d=rot6d_smooth,  # (B, L, J, 6)
            transl=transl_smooth,  # (B, L, 3)
            root_rotations_mat=root_rotmat_smooth,  # (B, L, 3, 3)
        )

    @staticmethod
    def smooth_with_savgol(input: torch.Tensor, window_length: int = 9, polyorder: int = 5) -> torch.Tensor:
        """Savitzky-Golay smooth each channel over time.

        Args:
            input: (L, C) or (B, L, C) tensor; smoothing runs along dim 1 (time).
            window_length: filter window (frames), must exceed polyorder.
            polyorder: polynomial order of the local fit.
        """
        if len(input.shape) == 2:
            is_batch = False
            input = input.unsqueeze(0)
        else:
            is_batch = True
        # scipy operates on numpy, so smoothing runs on CPU and is copied back.
        input_np = input.cpu().numpy()
        input_smooth_np = np.empty_like(input_np, dtype=np.float32)
        for b in range(input_np.shape[0]):
            for j in range(input_np.shape[2]):
                input_smooth_np[b, :, j] = savgol_filter(input_np[b, :, j], window_length, polyorder)
        input_smooth = torch.from_numpy(input_smooth_np).to(input)
        if not is_batch:
            input_smooth = input_smooth.squeeze(0)
        return input_smooth

    @staticmethod
    def smooth_with_slerp(input: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
        """Smooth rotations over time in quaternion space, then convert back to rot6d.

        Args:
            input: (B, L, J, 6) rot6d rotations.
            sigma: smoothing strength passed to `smooth_rotation`.
        """

        def fix_time_continuity(q: Tensor, time_dim: int = -3):
            # Flip quaternion signs so consecutive frames stay in the same hemisphere
            # (q and -q encode the same rotation but break interpolation).
            shape = q.shape
            qv = q.moveaxis(time_dim, 0).contiguous().view(shape[time_dim], -1, 4)
            qv = quaternion_fix_continuity(qv)
            return qv.view(shape[time_dim], *shape[:time_dim], *shape[time_dim + 1 :]).moveaxis(0, time_dim)

        num_joints = input.shape[2]
        RR = rot6d_to_rotation_matrix(input)
        qq = matrix_to_quaternion(RR)
        qq_np = fix_time_continuity(qq, time_dim=1).cpu().numpy()
        qq_s_np = smooth_rotation(
            qq_np,
            sigma=sigma,
        )
        input_smooth = rotation_matrix_to_rot6d(quaternion_to_matrix(torch.from_numpy(qq_s_np)))
        return input_smooth.to(input.device)

    @staticmethod
    def noise_from_seeds(
        latent: Tensor, seeds: Union[int, List[int]], seed_start: int = 0, random_generator_on_gpu: bool = True
    ) -> Tensor:
        """Create one deterministic noise sample per seed, concatenated along dim 0.

        Args:
            latent: template tensor providing shape (past dim 0), device, dtype.
            seeds: explicit seed list, or an int n meaning seeds 0..n-1.
            seed_start: offset added to every seed.
            random_generator_on_gpu: seed a generator on latent.device instead of CPU
                (the two produce different streams, so this changes the samples).
        """
        if isinstance(seeds, int):
            seeds = list(range(seeds))
        noise_list = []
        B = latent.shape[0]
        shape = (B, *latent.shape[1:])
        for seed in seeds:
            if random_generator_on_gpu:
                generator = torch.Generator(device=latent.device).manual_seed(seed + seed_start)
                noise_sample = randn_tensor(shape, generator=generator, device=latent.device, dtype=latent.dtype)
            else:
                generator = torch.Generator().manual_seed(seed + seed_start)
                noise_sample = randn_tensor(shape, generator=generator, dtype=latent.dtype).to(latent.device)
            noise_list.append(noise_sample)
        return torch.cat(noise_list, dim=0)

    def _maybe_inject_source_token(
        self,
        vtxt_input: Tensor,
        ctxt_input: Tensor,
        ctxt_mask_temporal: Tensor,
        sources: Optional[List[str]],
        trigger_sources: Optional[set] = None,
        prob: float = 0.5,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Inject the learned "special game" token for samples from trigger sources.

        For each hit sample: the special vtxt token is added to the pooled text
        feature, and the special ctxt token is appended after the last valid
        context token (extending the sequence by one if it is already full).
        During training, injection happens with probability `prob`; at eval it
        always happens for hit samples.

        Returns:
            Possibly-modified (vtxt_input, ctxt_input, ctxt_mask_temporal);
            ctxt tensors may grow by one along dim 1.
        """
        if (sources is None or trigger_sources is None) or not self.enable_special_game_feat:
            return vtxt_input, ctxt_input, ctxt_mask_temporal

        B, Lc, Dc = ctxt_input.shape
        assert (
            isinstance(sources, (list, tuple)) and len(sources) == B
        ), f"sources length should be equal to batch: {len(sources)} vs {B}"

        # Case-insensitive source matching.
        trig = set(s.lower() for s in trigger_sources)
        src_mask = torch.tensor(
            [str(s).lower() in trig for s in sources], dtype=torch.bool, device=ctxt_input.device
        )  # (B,)
        if not src_mask.any():
            return vtxt_input, ctxt_input, ctxt_mask_temporal

        rand_mask = (
            torch.rand(B, device=ctxt_input.device) < prob
            if self.training
            else torch.BoolTensor(B).fill_(True).to(ctxt_input.device)
        )
        apply_mask = src_mask & rand_mask
        if not apply_mask.any():
            return vtxt_input, ctxt_input, ctxt_mask_temporal

        # vtxt: only add mixture to the hit samples
        vtxt_token = self.special_game_vtxt_feat.to(vtxt_input).expand(B, 1, -1)
        vtxt_input = vtxt_input + vtxt_token * apply_mask.view(B, 1, 1).to(vtxt_input.dtype)

        # calculate the current effective length of each sample
        if ctxt_mask_temporal.dtype == torch.bool:
            cur_len = ctxt_mask_temporal.sum(dim=1).long()  # (B,)
        else:
            cur_len = (ctxt_mask_temporal > 0).sum(dim=1).long()

        # for the "not full" hit samples,
        # write the special token at the cur_len position,
        # and set the mask to True
        can_inplace = apply_mask & (cur_len < Lc)
        b_inplace = torch.nonzero(can_inplace, as_tuple=False).squeeze(1)  # (K,)
        if b_inplace.numel() > 0:
            pos = cur_len[b_inplace]  # (K,)
            token = self.special_game_ctxt_feat.squeeze(0).squeeze(0).to(ctxt_input)  # (Dc,)
            ctxt_input[b_inplace, pos, :] = token.unsqueeze(0).expand(b_inplace.numel(), Dc)
            if ctxt_mask_temporal.dtype == torch.bool:
                ctxt_mask_temporal[b_inplace, pos] = True
            else:
                ctxt_mask_temporal[b_inplace, pos] = 1

        # if there are "full" hit samples, need to pad one:
        # the full samples write the special token at the new position,
        # other samples pad zero and mask=False
        need_expand = (apply_mask & (cur_len >= Lc)).any()
        if need_expand:
            suffix = torch.zeros((B, 1, Dc), dtype=ctxt_input.dtype, device=ctxt_input.device)
            full_hit = apply_mask & (cur_len >= Lc)
            b_full = torch.nonzero(full_hit, as_tuple=False).squeeze(1)
            if b_full.numel() > 0:
                suffix[b_full, 0, :] = (
                    self.special_game_ctxt_feat.expand(b_full.numel(), 1, -1).to(ctxt_input).squeeze(1)
                )
            ctxt_input = torch.cat([ctxt_input, suffix], dim=1)

            if ctxt_mask_temporal.dtype == torch.bool:
                suffix_mask = torch.zeros((B, 1), dtype=torch.bool, device=ctxt_input.device)
                suffix_mask[b_full, 0] = True
            else:
                suffix_mask = torch.zeros((B, 1), dtype=ctxt_mask_temporal.dtype, device=ctxt_input.device)
                suffix_mask[b_full, 0] = 1
            ctxt_mask_temporal = torch.cat([ctxt_mask_temporal, suffix_mask], dim=1)

        return vtxt_input, ctxt_input, ctxt_mask_temporal
444
+
445
+
446
class MotionFlowMatching(MotionGeneration):
    """Flow-matching sampler on top of `MotionGeneration`.

    Integrates the learned velocity field from t=0 (noise) to t=1 (data) with
    `torchdiffeq.odeint`, with optional classifier-free guidance on the text
    conditioning, then decodes the final latent into motion.
    """

    def __init__(
        self,
        network_module: str,
        network_module_args: dict,
        text_encoder_module: str,
        text_encoder_cfg: dict,
        noise_scheduler_cfg: dict = {"method": "euler"},
        infer_noise_scheduler_cfg: dict = {"validation_steps": 50},
        mean_std_dir: Optional[str] = None,
        losses_cfg: Optional[dict] = None,
        train_cfg: Optional[dict] = None,
        test_cfg: Optional[dict] = None,
        **kwargs,
    ):
        """
        Args:
            noise_scheduler_cfg: extra kwargs forwarded to `odeint` (e.g. method).
                NOTE(review): mutable dict defaults are shared across instances;
                they are deep-copied below, so this is benign but fragile.
            infer_noise_scheduler_cfg: must contain "validation_steps" (ODE steps).
            mean_std_dir: stats dir; falls back to test_cfg["mean_std_dir"]
                (test_cfg must then be non-None).
        """
        super().__init__(
            network_module=network_module,
            network_module_args=network_module_args,
            text_encoder_module=text_encoder_module,
            text_encoder_cfg=text_encoder_cfg,
            losses_cfg=losses_cfg,
            mean_std_dir=(mean_std_dir if mean_std_dir is not None else test_cfg.get("mean_std_dir", None)),
            **kwargs,
        )
        # build scheduler
        self._noise_scheduler_cfg = deepcopy(noise_scheduler_cfg)
        self._infer_noise_scheduler_cfg = deepcopy(infer_noise_scheduler_cfg)
        # additional cfg
        self.train_cfg = deepcopy(train_cfg) if train_cfg else dict()
        self.test_cfg = deepcopy(test_cfg) if test_cfg else dict()
        self._parse_test_cfg()

    def _parse_test_cfg(self) -> None:
        """Cache inference settings: ODE step count and CFG scale."""
        self.validation_steps = self._infer_noise_scheduler_cfg["validation_steps"]
        self.text_guidance_scale = self.test_cfg.get("text_guidance_scale", 1)

    @torch.no_grad()
    def generate(
        self,
        text: Union[str, List[str]],
        seed_input: List[int],
        duration_slider: int,
        cfg_scale: Optional[float] = None,
        use_special_game_feat: bool = False,
        hidden_state_dict=None,
        length=None,
    ) -> Dict[str, Any]:
        """Sample motion for `text`, one sample per seed in `seed_input`.

        Args:
            text: one prompt shared by all seeds, or one prompt per seed.
            seed_input: noise seeds; output batch size equals len(seed_input).
            duration_slider: requested duration in seconds (used when length is None).
            cfg_scale: overrides the configured text guidance scale if given.
            use_special_game_feat: tag all samples with the "Game" source token.
            hidden_state_dict: precomputed output of `encode_text` to skip encoding.
            length: requested length in frames; clamped to the trained range.

        Returns:
            Decoded motion dict (see `_decode_o6dp`) plus the input "text".
        """
        device = get_module_device(self)
        if length is None:
            length = int(round(duration_slider * self.output_mesh_fps))
        assert (
            0 < length < 5000
        ), f"input duration_slider must be in (0, {5000/self.output_mesh_fps}] due to rope, but got {duration_slider}"
        if length > self.train_frames or length < min(self.train_frames, 20):
            # Clamp to [min(train_frames, 20), train_frames].
            print(f">>> given length is too long or too short, got {length}, will be truncated")
            length = min(length, self.train_frames)
            length = max(length, min(self.train_frames, 20))

        repeat = len(seed_input)
        if isinstance(text, list):
            assert len(text) == repeat, f"len(text) must equal len(seed_input), got {len(text)} vs {repeat}"
            text_list = text
        elif isinstance(text, str):
            text_list = [text] * repeat
        else:
            raise TypeError(f"Unsupported text type: {type(text)}")

        if not self.uncondition_mode:
            if hidden_state_dict is None:
                hidden_state_dict = self.encode_text({"text": text_list})
            vtxt_input = hidden_state_dict["text_vec_raw"]
            ctxt_input = hidden_state_dict["text_ctxt_raw"]
            ctxt_length = hidden_state_dict["text_ctxt_raw_length"]
            # check shape
            if len(vtxt_input.shape) == 2 and len(ctxt_input.shape) == 2:
                # Un-batched encoder output: tile it across the seed batch.
                vtxt_input = vtxt_input[None].repeat(repeat, 1, 1)
                ctxt_input = ctxt_input[None].repeat(repeat, 1, 1)
                ctxt_length = ctxt_length.repeat(repeat)
            ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
            sources = None if not use_special_game_feat else ["Game"] * repeat
            vtxt_input, ctxt_input, ctxt_mask_temporal = self._maybe_inject_source_token(
                vtxt_input, ctxt_input, ctxt_mask_temporal, sources, trigger_sources={"Taobao", "Game"}
            )
        else:
            # Unconditional: replace all conditioning with the learned null tokens.
            vtxt_input = self.null_vtxt_feat.expand(repeat, 1, -1)
            ctxt_input = self.null_ctxt_input.expand(repeat, 1, -1)
            ctxt_length = torch.tensor([1]).expand(repeat)
            ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1]).expand(repeat, -1)
        assert len(vtxt_input.shape) == 3, f"vtxt_input.shape: {vtxt_input.shape}, should be (B, 1, D)"
        assert len(ctxt_input.shape) == 3, f"ctxt_input.shape: {ctxt_input.shape}, should be (B, 1, D)"
        assert len(ctxt_length.shape) == 1, f"ctxt_length.shape: {ctxt_length.shape}, should be (B,)"

        # NOTE(review): this recomputation from ctxt_length overwrites the mask
        # returned by _maybe_inject_source_token above, so an injected special
        # token's mask position appears to be discarded — confirm intent.
        ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
        x_length = torch.LongTensor([length] * repeat).to(device)
        x_mask_temporal = length_to_mask(x_length, self.train_frames)

        text_guidance_scale = cfg_scale if cfg_scale is not None else self.text_guidance_scale
        do_classifier_free_guidance = text_guidance_scale > 1.0 and not self.uncondition_mode
        if do_classifier_free_guidance is True:
            # Batch layout for CFG: [unconditional half, conditional half].
            silent_text_feat = self.null_vtxt_feat.expand(*vtxt_input.shape)
            vtxt_input = torch.cat([silent_text_feat, vtxt_input], dim=0)

            if self.enable_ctxt_null_feat:
                silent_ctxt_input = self.null_ctxt_input.expand(*ctxt_input.shape)
            else:
                silent_ctxt_input = ctxt_input
            ctxt_input = torch.cat([silent_ctxt_input, ctxt_input], dim=0)

            ctxt_mask_temporal = torch.cat([ctxt_mask_temporal] * 2, dim=0)
            x_mask_temporal = torch.cat([x_mask_temporal] * 2, dim=0)

        def fn(t: Tensor, x: Tensor) -> Tensor:
            # predict flow
            x_input = torch.cat([x] * 2, dim=0) if do_classifier_free_guidance else x
            x_pred = self.motion_transformer(
                x=x_input,
                ctxt_input=ctxt_input,
                vtxt_input=vtxt_input,
                timesteps=t.expand(x_input.shape[0]),
                x_mask_temporal=x_mask_temporal,
                ctxt_mask_temporal=ctxt_mask_temporal,
            )
            if do_classifier_free_guidance:
                # Standard CFG combination: uncond + scale * (cond - uncond).
                x_pred_basic, x_pred_text = x_pred.chunk(2, dim=0)
                x_pred = x_pred_basic + text_guidance_scale * (x_pred_text - x_pred_basic)
            return x_pred

        # duplicate test corner for inner time step oberservation
        t = torch.linspace(0, 1, self.validation_steps + 1, device=device)
        # One noise sample per seed; the zeros tensor only provides shape/device/dtype.
        y0 = self.noise_from_seeds(
            torch.zeros(
                1,
                self.train_frames,
                self._network_module_args["input_dim"],
                device=device,
            ),
            seed_input,
            random_generator_on_gpu=self.random_generator_on_gpu,
        )
        with torch.no_grad():
            trajectory = odeint(fn, y0, t, **self._noise_scheduler_cfg)
        sampled = trajectory[-1]
        assert isinstance(sampled, Tensor), f"sampled must be a Tensor, but got {type(sampled)}"
        # Drop the padded tail beyond the requested length.
        sampled = sampled[:, :length, ...].clone()

        output_dict = self.decode_motion_from_latent(sampled, should_apply_smooothing=True)

        return {
            **output_dict,
            "text": text,
        }
597
+
598
+
599
if __name__ == "__main__":
    # Smoke test; run with: python -m hymotion.pipeline.motion_diffusion
    import torch

    device = "cuda:0"
    input_dim = 272
    ctxt_seq_lens = 64

    network_module = "hymotion/network/hymotion_mmdit.HunyuanMotionMMDiT"
    network_module_args = {
        "input_dim": input_dim,
        "feat_dim": 512,
        "ctxt_input_dim": 4096,
        "vtxt_input_dim": 768,
        "num_layers": 12,
        "num_heads": 4,
        "mlp_ratio": 2.0,
        "dropout": 0.0,
        "mask_mode": "narrowband",
    }
    text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel"
    text_encoder_cfg = {"llm_type": "qwen3", "max_length_llm": ctxt_seq_lens}

    # ================================ FM_MMDiT ================================
    # Fix: the keyword was `noise_scheduler_module`, which does not exist on
    # MotionFlowMatching and was silently swallowed by **kwargs; the intended
    # parameter is `noise_scheduler_cfg` (kwargs forwarded to odeint).
    FM_MMDiT = MotionFlowMatching(
        network_module=network_module,
        network_module_args=network_module_args,
        text_encoder_module=text_encoder_module,
        text_encoder_cfg=text_encoder_cfg,
        noise_scheduler_cfg={"method": "euler"},
        infer_noise_scheduler_cfg={"validation_steps": 50},
        train_cfg={"cond_mask_prob": 0.1},
        test_cfg={
            "text_guidance_scale": 1.5,
        },
    ).to(device)
hymotion/prompt_engineering/model_constants.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
__all__ = [
    "REWRITE_AND_INFER_TIME_PROMPT_FORMAT",
]

# System-prompt template for the rewrite-and-duration LLM call.
# Formatting contract: filled via `str.format` — the single `{}` placeholder
# under "# Input" receives the user's action description; literal JSON braces
# in the "# JSON Structure" section are escaped as `{{` / `}}`.
REWRITE_AND_INFER_TIME_PROMPT_FORMAT = """
# Role
You are an expert in 3D motion analysis, animation timing, and choreography. Your task is to analyze textual action descriptions to estimate execution time and standardize the language for motion generation systems.

# Task
Analyze the user-provided [Input Action] and generate a structured JSON response containing a duration estimate and a refined caption.

# Instructions

### 1. Duration Estimation (frame_count)
- Analyze the complexity, speed, and physical constraints of the described action.
- Estimate the time required to perform the action in a **smooth, natural, and realistic manner**.
- Calculate the total duration in frames based on a **30 fps** (frames per second) standard.
- Output strictly as an Integer.

### 2. Caption Refinement (short_caption)
- Generate a refined, grammatically correct version of the input description in **English**.
- **Strict Constraints**:
  - You must **PRESERVE** the original sequence of events (chronological order).
  - You must **RETAIN** all original spatial modifiers (e.g., "left," "upward," "quickly").
  - **DO NOT** add new sub-actions or hallucinate details not present in the input.
  - **DO NOT** delete any specific movements.
- The goal is to improve clarity and flow while maintaining 100% semantic fidelity to the original request.

### 3. Output Format
- Return **ONLY** a raw JSON object.
- Do not use Markdown formatting (i.e., do not use ```json ... ```).
- Ensure the JSON is valid and parsable.

# JSON Structure
{{
  "duration": <Integer, frames at 30fps>,
  "short_caption": "<String, the refined English description>"
}}

# Input
{}
"""
hymotion/prompt_engineering/prompt_rewrite.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # prompt_rewrite.py
2
+ import base64
3
+ import concurrent.futures
4
+ import datetime
5
+ import hashlib
6
+ import hmac
7
+ import json
8
+ import logging
9
+ import random
10
+ import re
11
+ import time
12
+ import uuid
13
+ from dataclasses import dataclass
14
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
15
+
16
+ from openai import OpenAI
17
+ from requests import exceptions as req_exc
18
+
19
+ from .model_constants import REWRITE_AND_INFER_TIME_PROMPT_FORMAT
20
+
21
+ # logging.basicConfig(level=logging.INFO)
22
+
23
+
24
@dataclass
class ApiConfig:
    """Connection settings for the chat-completion API client."""

    host: str  # base URL passed to the OpenAI client as base_url
    user: str  # caller identity (not read by OpenAIChatApi itself)
    apikey: str  # API key passed to the OpenAI client
    model: str  # model name sent with every request
    api_version: Optional[str] = None  # optional API version tag
    timeout: int = 3600  # request timeout in seconds
    source: str = "hymotion"  # originating-project tag
33
+
34
+
35
@dataclass
class RetryConfig:
    """Retry policy for API calls."""

    max_retries: int = 20  # maximum number of attempts
    base_delay: float = 1.0  # base backoff delay in seconds
    timeout: float = 30.0  # per-attempt timeout in seconds
    retry_status: Tuple[int, ...] = (429, 500, 502, 503, 504)  # HTTP codes worth retrying
    # NOTE(review): max_delay equals base_delay, so backoff cannot grow — confirm intended.
    max_delay: float = 1.0  # upper bound on backoff delay in seconds
42
+
43
+
44
class ApiError(Exception):
    """Raised when the remote chat-completion API call fails."""

    pass
46
+
47
+
48
class ResponseParseError(Exception):
    """Raised when an API response cannot be parsed into the expected structure."""

    pass
50
+
51
+
52
class OpenAIChatApi:
    """Thin wrapper around the OpenAI chat-completions client."""

    def __init__(self, config: ApiConfig) -> None:
        """Create the client from an ApiConfig (host -> base_url, apikey -> api_key)."""
        self.logger = logging.getLogger(__name__)
        self.config = config
        self.client = OpenAI(
            api_key=self.config.apikey,
            base_url=self.config.host,
        )

    def call_data_eval(self, data: Union[str, Dict[str, Any]]):
        """Send one chat-completion request and return the raw response object.

        Args:
            data: either a plain prompt (anything stringifiable -> single user
                message), or a dict with a "messages" list (OpenAI-style).
                Multimodal list contents are flattened to their text parts;
                recognized sampling keys are copied into the payload.

        Raises:
            ApiError: wraps any exception raised by the underlying client.
        """
        if isinstance(data, dict) and "messages" in data:
            raw_msgs = data["messages"]
            messages: List[Dict[str, str]] = []
            for m in raw_msgs:
                role = m.get("role", "user")
                content = m.get("content", "")
                if isinstance(content, list):
                    # Multimodal content: keep only the textual parts, joined by spaces.
                    parts = []
                    for p in content:
                        if isinstance(p, dict) and ("text" in p):
                            parts.append(str(p.get("text", "")))
                    content = " ".join([t for t in parts if t])
                elif not isinstance(content, str):
                    content = str(content)
                messages.append({"role": role, "content": content})
            payload = {"model": self.config.model, "messages": messages}
            # Forward recognized OpenAI sampling parameters verbatim.
            for k in (
                "temperature",
                "top_p",
                "max_tokens",
                "n",
                "stop",
                "presence_penalty",
                "frequency_penalty",
                "user",
            ):
                if k in data:
                    payload[k] = data[k]
        else:
            # Anything else is treated as a single user message.
            payload = {"model": self.config.model, "messages": [{"role": "user", "content": str(data)}]}
        try:
            resp = self.client.chat.completions.create(**payload)
            return resp
        except Exception as e:
            self.logger.error(f"OpenAI API call failed: {e}")
            raise ApiError(f"OpenAI API call failed: {e}") from e
98
+
99
+
100
+ class ResponseParser:
101
+ def __init__(self):
102
+ self.logger = logging.getLogger(__name__)
103
+
104
+ def call_data_eval_with_retry(
105
+ self, api: Union[OpenAIChatApi], data: str, retry_config: Optional[RetryConfig] = None
106
+ ) -> Tuple[Union[Dict[str, Any], int], float, float]:
107
+ if retry_config is None:
108
+ retry_config = RetryConfig()
109
+
110
+ last_error = None
111
+ for attempt in range(retry_config.max_retries):
112
+ start_time = time.time()
113
+ cost = 0.0
114
+
115
+ try:
116
+ result = self._execute_request(api, data)
117
+ end_time = time.time()
118
+ parsed_result = self._parse_answer(result)
119
+ self._validate_result(parsed_result)
120
+ return parsed_result, cost, end_time - start_time
121
+
122
+ except (
123
+ concurrent.futures.TimeoutError,
124
+ req_exc.RequestException,
125
+ json.JSONDecodeError,
126
+ ValueError,
127
+ TypeError,
128
+ ResponseParseError,
129
+ ) as e:
130
+ last_error = e
131
+ self.logger.warning(f"Attempt {attempt + 1} failed: {e}")
132
+ if isinstance(e, req_exc.RequestException) and hasattr(e, "response"):
133
+ if e.response is not None and e.response.status_code not in retry_config.retry_status:
134
+ raise ApiError(f"Non-retryable error: {e.response.status_code}") from e
135
+ if attempt < retry_config.max_retries - 1:
136
+ delay = self._calculate_delay(attempt, retry_config)
137
+ self.logger.info(f"JSON parsing failed, {delay:.1f} seconds later retry...")
138
+ time.sleep(delay)
139
+
140
+ raise ApiError(f"Retry {retry_config.max_retries} times but still failed") from last_error
141
+
142
+ def _execute_request(self, api: Union[OpenAIChatApi], data: str) -> Dict[str, Any]:
143
+ response = api.call_data_eval(data)
144
+
145
+ try:
146
+ if hasattr(response, "model_dump"):
147
+ return response.model_dump()
148
+ if isinstance(response, dict):
149
+ return response
150
+ if hasattr(response, "__dict__"):
151
+ return json.loads(json.dumps(response.__dict__, default=str))
152
+ except Exception as e:
153
+ raise ResponseParseError(f"Unable to parse OpenAI returned object: {type(response)} - {e}") from e
154
+
155
+ raise ResponseParseError(f"Unknown response type: {type(response)}")
156
+
157
+ def _extract_cost(self, payload: Dict[str, Any]) -> float:
158
+ try:
159
+ return float(payload.get("cost_info", {}).get("cost", 0)) / 1e6
160
+ except (AttributeError, KeyError):
161
+ return 0.0
162
+
163
+ def _validate_result(self, result: Union[Dict[str, Any], int]) -> None:
164
+ if isinstance(result, int):
165
+ return
166
+ elif isinstance(result, dict):
167
+ required_fields = ["duration", "short_caption"]
168
+ for field in required_fields:
169
+ if not isinstance(result.get(field), (int, str)):
170
+ raise ResponseParseError(f"LLM returned invalid format: {field}")
171
+ else:
172
+ raise ResponseParseError(f"Unsupported answer type: {type(result)}")
173
+
174
+ def _calculate_delay(self, attempt: int, config: RetryConfig) -> float:
175
+ delay = config.base_delay * (2**attempt) * (0.5 + random.random())
176
+ return min(delay, config.max_delay)
177
+
178
+ def _parse_answer(self, payload: Dict[str, Any]) -> Dict[str, Any]:
179
+ if isinstance(payload, dict) and "choices" in payload:
180
+ return self._parse_from_choices_field(payload)
181
+
182
+ raise ResponseParseError("Unknown response format: expected choices")
183
+
184
+ def _parse_from_choices_field(self, payload: Dict[str, Any]) -> Dict[str, Any]:
185
+ choices = payload.get("choices") or []
186
+ if not choices:
187
+ raise ResponseParseError("OpenAI returned empty")
188
+
189
+ content = self._extract_content_from_choice(choices[0])
190
+
191
+ if not isinstance(content, str) or not content.strip():
192
+ raise ResponseParseError("OpenAI returned no valid content")
193
+
194
+ return self._parse_json_content(content)
195
+
196
+ def _extract_content_from_choice(self, choice: Any) -> Optional[str]:
197
+ content = None
198
+
199
+ if isinstance(choice, dict):
200
+ # Try message content first
201
+ msg = choice.get("message") or {}
202
+ content = msg.get("content")
203
+ # Fallback to delta content or text
204
+ if content is None:
205
+ delta = choice.get("delta") or {}
206
+ content = delta.get("content", choice.get("text"))
207
+ else:
208
+ # Handle object-like choice (e.g. Pydantic model)
209
+ msg = getattr(choice, "message", None)
210
+ if msg is not None:
211
+ content = getattr(msg, "content", None)
212
+
213
+ if content is None:
214
+ delta = getattr(choice, "delta", None)
215
+ if delta is not None:
216
+ content = getattr(delta, "content", None)
217
+
218
+ if content is None:
219
+ content = getattr(choice, "text", None)
220
+
221
+ return content
222
+
223
+ def _parse_json_content(self, content: str) -> Dict[str, Any]:
224
+ cleaned = self._cleanup_fenced_json(content)
225
+ try:
226
+ return json.loads(cleaned)
227
+ except json.JSONDecodeError as e:
228
+ self.logger.warning(f"JSON parsing failed, original content: {cleaned[:500]}...")
229
+ raise ResponseParseError(f"JSON parsing failed: {e}") from e
230
+
231
+ def _cleanup_fenced_json(self, text: str) -> str:
232
+ text = text.strip()
233
+ if text.startswith("```"):
234
+ text = re.sub(r"^```(?:json)?\s*", "", text)
235
+ text = re.sub(r"\s*```$", "", text)
236
+ if not text.lstrip().startswith("{") and "{" in text and "}" in text:
237
+ start = text.find("{")
238
+ end = text.rfind("}")
239
+ if 0 <= start < end:
240
+ text = text[start : end + 1]
241
+ return text
242
+
243
+
244
class PromptRewriter:
    """Rewrites a free-form motion prompt and infers its duration via an LLM."""

    def __init__(self, host: Optional[str] = None, parser: Optional[ResponseParser] = None):
        self.parser = parser or ResponseParser()
        self.logger = logging.getLogger(__name__)
        api_config = ApiConfig(
            host=host,
            user="",
            apikey="EMPTY",
            model="Qwen3-30B-A3B-SFT",
            api_version="",
        )
        self.api = OpenAIChatApi(api_config)

    def rewrite_prompt_and_infer_time(
        self,
        text: str,
        prompt_format: str = REWRITE_AND_INFER_TIME_PROMPT_FORMAT,
        retry_config: Optional[RetryConfig] = None,
    ) -> Tuple[float, str]:
        """Return (duration_seconds, rewritten_caption) for *text*.

        Raises whatever the underlying retry loop raises on failure.
        """
        self.logger.info("Start rewriting prompt...")
        try:
            prompt = prompt_format.format(text)
            result, cost, elapsed = self.parser.call_data_eval_with_retry(self.api, prompt, retry_config)
            self.logger.info(f"Rewriting completed - cost: {cost:.6f}, time: {elapsed:.2f}s")
            # "duration" comes back in frames; presumably 30 fps -> seconds (TODO confirm).
            duration_seconds = round(float(result["duration"]) / 30.0, 2)
            return duration_seconds, result["short_caption"]
        except Exception as e:
            self.logger.error(f"Prompt rewriting failed: {e}")
            raise
275
+
276
+
277
if __name__ == "__main__":
    # Manual smoke test for the rewriter. Run as a module:
    # python -m hymotion.prompt_engineering.prompt_rewrite

    logging.basicConfig(level=logging.INFO)
    text = "person jumps after they runs"
    # NOTE(review): host=None means the OpenAI client falls back to its default
    # base URL -- confirm a local endpoint is expected to be configured instead.
    prompt_rewriter = PromptRewriter()
    result = prompt_rewriter.rewrite_prompt_and_infer_time(text)
    print(result)
hymotion/utils/configs.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import copy
3
+ import os.path as osp
4
+ import platform
5
+ import re
6
+ import shutil
7
+ import sys
8
+ import tempfile
9
+ import types
10
+ import uuid
11
+ from importlib import import_module
12
+ from pathlib import Path
13
+ from typing import Any, Dict, Iterator, NoReturn, Optional, Union
14
+ import yaml
15
+
16
+ from .misc import import_modules_from_strings
17
+ from .path import check_file_exist
18
+
19
# Key in a config dict listing base config file(s) to inherit from.
BASE_KEY = "_base_"
# Sentinel key: when truthy in a child dict, replace the base dict instead of merging.
DELETE_KEY = "_delete_"
# Names used by Config itself; user config files may not define them.
RESERVED_KEYS = ["filename", "text", "pretty_text"]
22
+
23
+
24
class Config:
    """A facility for config files (mmcv-style).

    Supports Python config files with ``_base_`` inheritance,
    ``{{ var }}`` predefined-variable substitution and ``{{ _base_.key }}``
    references, plus YAML round-tripping via save_yaml/load_yaml.
    """

    def __init__(
        self,
        cfg_dict: Optional[dict] = None,
        cfg_text: Optional[str] = None,
        filename: Optional[str] = None,
    ) -> None:
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f"{key} is reserved for config file")

        if isinstance(filename, Path):
            filename = str(filename)

        # Bypass our own __setattr__, which writes into _cfg_dict.
        super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
        super(Config, self).__setattr__("_filename", filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, "r") as f:
                text = f.read()
        else:
            text = ""
        super(Config, self).__setattr__("_text", text)

    @property
    def filename(self) -> Optional[str]:
        """Path this config was loaded from (None for in-memory configs).

        BUGFIX: previously ``cfg.filename`` fell through __getattr__ into the
        config dict -- where "filename" is a reserved, never-present key -- and
        raised AttributeError, which also broke __repr__.
        """
        return self._filename

    @property
    def text(self) -> str:
        """Raw config text (file content, plus inherited bases when present)."""
        return self._text

    @staticmethod
    def fromfile(
        filename: str,
        use_predefined_variables: bool = True,
        import_custom_modules: bool = True,
    ) -> "Config":
        """Load a .py config file, resolving bases and variable templates."""
        if isinstance(filename, Path):
            filename = str(filename)
        cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
        if import_custom_modules and cfg_dict.get("custom_imports", None):
            import_modules_from_strings(**cfg_dict["custom_imports"])
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def _file2dict(filename: str, use_predefined_variables: bool = True) -> tuple[dict, str]:
        """Execute a .py config file in a temp sandbox and return (dict, text)."""
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in [".py"]:
            raise IOError("Only py type are supported now!")

        cfg_dict = {}

        with tempfile.TemporaryDirectory() as temp_config_dir:
            temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=fileExtname)
            if platform.system() == "Windows":
                # Windows cannot reopen an open NamedTemporaryFile.
                temp_config_file.close()
            temp_config_name = osp.basename(temp_config_file.name)
            # Substitute predefined variables ({{ fileDirname }} etc.).
            if use_predefined_variables:
                Config._substitute_predefined_vars(filename, temp_config_file.name)
            else:
                shutil.copyfile(filename, temp_config_file.name)
            # Substitute base variables from placeholders to unique strings.
            base_var_dict = Config._pre_substitute_base_vars(temp_config_file.name, temp_config_file.name)

            if filename.endswith(".py"):
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                Config._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                # Keep only plain values; drop dunders, modules and functions.
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith("__")
                    and not isinstance(value, types.ModuleType)
                    and not isinstance(value, types.FunctionType)
                }
                # Delete imported module so a later load re-executes it.
                del sys.modules[temp_module_name]

            # Close temp file.
            temp_config_file.close()

        cfg_text = filename + "\n"
        with open(filename, "r", encoding="utf-8") as f:
            # Setting encoding explicitly to resolve coding issue on windows.
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                duplicate_keys = base_cfg_dict.keys() & c.keys()
                if len(duplicate_keys) > 0:
                    raise KeyError("Duplicate key is not allowed among bases. " f"Duplicate keys: {duplicate_keys}")
                base_cfg_dict.update(c)

            # Substitute base variables from strings to their actual values.
            cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, base_cfg_dict)
            assert isinstance(cfg_dict, dict)

            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # Merge cfg_text of bases and the child file.
            cfg_text_list.append(cfg_text)
            cfg_text = "\n".join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _validate_py_syntax(filename: str) -> None:
        """Raise SyntaxError early (with the real file name) on a broken config."""
        with open(filename, "r", encoding="utf-8") as f:
            # Setting encoding explicitly to resolve coding issue on windows.
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as e:
            raise SyntaxError("There are syntax errors in config " f"file (unknown): {e}")

    @staticmethod
    def _pre_substitute_base_vars(filename: str, temp_config_name: str) -> dict:
        """Substitute base variable placeholders with unique string tokens so
        the file stays importable; returns token -> dotted-key mapping."""
        with open(filename, "r", encoding="utf-8") as f:
            config_file = f.read()
        base_var_dict = {}
        regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
        base_vars = set(re.findall(regexp, config_file))
        for base_var in base_vars:
            randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
            base_var_dict[randstr] = base_var
            regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
            config_file = re.sub(regexp, f'"{randstr}"', config_file)
        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
            tmp_config_file.write(config_file)
        return base_var_dict

    @staticmethod
    def _substitute_base_vars(
        cfg: Union[dict, list, tuple, str],
        base_var_dict: dict,
        base_cfg: dict,
    ) -> Union[dict, list, tuple, str]:
        """Recursively replace placeholder tokens with values from base_cfg."""
        cfg = copy.deepcopy(cfg)

        if isinstance(cfg, dict):
            for k, v in cfg.items():
                if isinstance(v, str) and v in base_var_dict:
                    new_v = base_cfg
                    for new_k in base_var_dict[v].split("."):
                        new_v = new_v[new_k]
                    cfg[k] = new_v
                elif isinstance(v, (list, tuple, dict)):
                    cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
        elif isinstance(cfg, tuple):
            cfg = tuple(Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg)
        elif isinstance(cfg, list):
            cfg = [Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg]
        elif isinstance(cfg, str) and cfg in base_var_dict:
            new_v = base_cfg
            for new_k in base_var_dict[cfg].split("."):
                new_v = new_v[new_k]
            cfg = new_v

        return cfg

    @staticmethod
    def _substitute_predefined_vars(filename: str, temp_config_name: str) -> None:
        """Expand {{ fileDirname }}-style templates into temp_config_name."""
        file_dirname = osp.dirname(filename)
        file_basename = osp.basename(filename)
        file_basename_no_extension = osp.splitext(file_basename)[0]
        file_extname = osp.splitext(filename)[1]
        support_templates = dict(
            fileDirname=file_dirname,
            fileBasename=file_basename,
            fileBasenameNoExtension=file_basename_no_extension,
            fileExtname=file_extname,
        )
        with open(filename, "r", encoding="utf-8") as f:
            config_file = f.read()
        for key, value in support_templates.items():
            regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
            value = value.replace("\\", "/")
            config_file = re.sub(regexp, value, config_file)
        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
            tmp_config_file.write(config_file)

    @staticmethod
    def _merge_a_into_b(a: dict, b: dict, allow_list_keys: bool = False) -> dict:
        """Recursively merge child dict *a* over base *b* (DELETE_KEY replaces)."""
        b = b.copy()
        for k, v in a.items():
            if allow_list_keys and k.isdigit() and isinstance(b, list):
                k = int(k)
                if len(b) <= k:
                    raise KeyError(f"Index {k} exceeds the length of list {b}")
                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
            elif isinstance(v, dict):
                if k in b and not v.pop(DELETE_KEY, False):
                    allowed_types = (dict, list) if allow_list_keys else dict
                    if not isinstance(b[k], allowed_types):
                        raise TypeError(
                            f"{k}={v} in child config cannot inherit from "
                            f"base because {k} is a dict in the child config "
                            f"but is of type {type(b[k])} in base config. "
                            f"You may set `{DELETE_KEY}=True` to ignore the "
                            f"base config."
                        )
                    b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
                else:
                    b[k] = ConfigDict(v)
            else:
                b[k] = v
        return b

    def to_dict(self) -> Any:
        """Recursively convert the stored ConfigDict into plain dicts/lists."""

        def convert_configdict(obj):
            if isinstance(obj, ConfigDict):
                return {k: convert_configdict(v) for k, v in obj.items()}
            elif isinstance(obj, dict):
                return {k: convert_configdict(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [convert_configdict(item) for item in obj]
            else:
                return obj

        return convert_configdict(self._cfg_dict)

    @classmethod
    def from_dict(cls, cfg_dict: dict, filename: Optional[str] = None) -> "Config":
        """Alternate constructor from an in-memory dict."""
        return cls(cfg_dict=cfg_dict, filename=filename)

    def save_yaml(self, filename: str) -> None:
        """Dump the config as YAML (plain types only)."""
        with open(filename, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f, default_flow_style=False, indent=2)

    @classmethod
    def load_yaml(cls, filename: str) -> "Config":
        """Load a YAML file previously written by save_yaml."""
        with open(filename, "r", encoding="utf-8") as f:
            cfg_dict = yaml.safe_load(f)
        return cls.from_dict(cfg_dict, filename=filename)

    def __repr__(self) -> str:
        return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"

    def __len__(self) -> int:
        return len(self._cfg_dict)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails; delegates to the config dict.
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name: str) -> Any:
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self) -> Iterator[Any]:
        return iter(self._cfg_dict)

    def __getstate__(self) -> tuple[dict, str, str]:
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state: tuple) -> None:
        # BUGFIX: __getstate__ returns a tuple, so without a matching
        # __setstate__ the default unpickling path (which expects a dict)
        # broke round-tripping Config through pickle/deepcopy-by-pickle.
        cfg_dict, filename, text = state
        super(Config, self).__setattr__("_cfg_dict", cfg_dict)
        super(Config, self).__setattr__("_filename", filename)
        super(Config, self).__setattr__("_text", text)

    def __copy__(self) -> "Config":
        cls = self.__class__
        other = cls.__new__(cls)
        other.__dict__.update(self.__dict__)

        return other

    def __deepcopy__(self, memo: dict) -> "Config":
        cls = self.__class__
        other = cls.__new__(cls)
        memo[id(self)] = other

        for key, value in self.__dict__.items():
            super(Config, other).__setattr__(key, copy.deepcopy(value, memo))

        return other
321
+
322
+
323
class ConfigDict(Dict):
    """A dict whose items are also accessible (read AND write) as attributes."""

    def __missing__(self, name: str) -> NoReturn:
        raise KeyError(name)

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        # BUGFIX: store attribute assignments as dict items so that
        # `cfg.key = v` is visible via `cfg["key"]`, iteration, to_dict() and
        # YAML serialization. Previously the value landed in the instance
        # __dict__ and was silently dropped by all of those.
        self[name] = value

    def to_dict(self) -> Any:
        """Recursively convert to plain dicts/lists."""

        def _convert(obj):
            # ConfigDict is itself a dict, so one branch covers both.
            if isinstance(obj, dict):
                return {k: _convert(v) for k, v in obj.items()}
            if isinstance(obj, (list, tuple)):
                return [_convert(item) for item in obj]
            return obj

        return _convert(dict(self))
hymotion/utils/geometry.py ADDED
@@ -0,0 +1,856 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
+
9
def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1].
    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    raw_x, raw_y = d6[..., :3], d6[..., 3:]
    # First basis vector: just normalize.
    e1 = F.normalize(raw_x, dim=-1)
    # Second: remove the component along e1, then normalize.
    proj = (e1 * raw_y).sum(-1, keepdim=True)
    e2 = F.normalize(raw_y - proj * e1, dim=-1)
    # Third: completes the right-handed orthonormal frame.
    e3 = torch.cross(e1, e2, dim=-1)
    return torch.stack((e1, e2, e3), dim=-2)
31
+
32
+
33
def matrix_to_rotation_6d(matrix: Tensor) -> Tensor:
    """
    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
    by dropping the last row. Note that 6D representation is not unique.
    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    leading_shape = matrix.size()[:-2]
    # Keep the first two rows; clone so the result owns its memory.
    top_rows = matrix[..., :2, :].clone()
    return top_rows.reshape(leading_shape + (6,))
50
+
51
+
52
def standardize_quaternion(quaternions: Tensor) -> Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # q and -q encode the same rotation; flip the sign where w < 0.
    real_part = quaternions[..., 0:1]
    return torch.where(real_part < 0, -quaternions, quaternions)
65
+
66
+
67
+ def _sqrt_positive_part(x: Tensor) -> Tensor:
68
+ """Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0."""
69
+ ret = torch.zeros_like(x)
70
+ positive_mask = x > 0
71
+ if torch.is_grad_enabled():
72
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
73
+ else:
74
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
75
+ return ret
76
+
77
+
78
def matrix_to_quaternion(matrix: Tensor) -> Tensor:
    """Convert rotations given as rotation matrices to quaternions.

    Uses the numerically robust branch-free variant (as in PyTorch3D): all four
    quaternion candidates are computed, and the best-conditioned one (largest
    |component|) is selected per element.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    # q_abs[..., i] = 2 * |i-th quaternion component| (r, i, j, k).
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
    return standardize_quaternion(out)
133
+
134
+
135
def quaternion_to_axis_angle(quaternions: Tensor) -> Tensor:
    """Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    # |imaginary part| = sin(angle/2) for a unit quaternion.
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    # (Taylor fallback keeps the division stable near angle = 0.)
    sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
    return quaternions[..., 1:] / sin_half_angles_over_angles
159
+
160
+
161
def matrix_to_axis_angle(matrix: Tensor) -> Tensor:
    """Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    # Go through the quaternion representation; both steps are batched.
    quats = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(quats)
174
+
175
+
176
def quaternion_to_matrix(quaternions: Tensor) -> Tensor:
    """Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    # Scaling by 2/|q|^2 makes the formula valid for non-unit quaternions too.
    s = 2.0 / (quaternions * quaternions).sum(-1)

    entries = (
        1 - s * (y * y + z * z),
        s * (x * y - z * w),
        s * (x * z + y * w),
        s * (x * y + z * w),
        1 - s * (x * x + z * z),
        s * (y * z - x * w),
        s * (x * z - y * w),
        s * (y * z + x * w),
        1 - s * (x * x + y * y),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
205
+
206
+
207
def axis_angle_to_quaternion(axis_angle: Tensor) -> Tensor:
    """Convert rotations given as axis/angle to quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    # (Taylor fallback keeps the ratio finite as angle -> 0.)
    sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
    # q = [cos(angle/2), axis * sin(angle/2)]; the ratio above folds in the
    # normalization of axis_angle by its magnitude.
    quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1)
    return quaternions
230
+
231
+
232
def axis_angle_to_matrix(axis_angle: Tensor) -> Tensor:
    """Convert rotations given as axis/angle to rotation matrices.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    # Compose the two batched conversions via the quaternion representation.
    quats = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(quats)
245
+
246
+
247
def get_T_w2c_from_wcparams(
    global_orient_w: Tensor, transl_w: Tensor, global_orient_c: Tensor, transl_c: Tensor, offset: Tensor
) -> Tensor:
    """
    Recover the per-frame world-to-camera rigid transform from the same root
    pose expressed in both world and camera coordinates.

    Args:
        global_orient_w: Tensor, (F, 3) -- axis-angle root orientation, world frame
        transl_w: Tensor, (F, 3) -- root translation, world frame
        global_orient_c: Tensor, (F, 3) -- axis-angle root orientation, camera frame
        transl_c: Tensor, (F, 3) -- root translation, camera frame
        offset: Tensor, (*, 3) -- root offset added to both translations
            (presumably a body-model pelvis offset -- TODO confirm with callers)
    Returns:
        T_w2c: Tensor, (F, 4, 4)
    """
    assert global_orient_w.shape == transl_w.shape and len(global_orient_w.shape) == 2
    assert global_orient_c.shape == transl_c.shape and len(global_orient_c.shape) == 2

    R_w = axis_angle_to_matrix(global_orient_w)  # (F, 3, 3)
    t_w = transl_w  # (F, 3)
    R_c = axis_angle_to_matrix(global_orient_c)  # (F, 3, 3)
    t_c = transl_c  # (F, 3)

    # From R_c = R_w2c @ R_w  =>  R_w2c = R_c @ R_w^T.
    R_w2c = R_c @ R_w.transpose(-1, -2)  # (F, 3, 3)
    # From (t_c + offset) = R_w2c @ (t_w + offset) + t_w2c.
    t_w2c = t_c + offset - torch.einsum("fij,fj->fi", R_w2c, t_w + offset)  # (F, 3)
    # NOTE(review): torch.eye uses the default dtype (float32); float64 inputs
    # would be silently downcast on the assignments below -- confirm callers
    # always pass float32.
    T_w2c = torch.eye(4, device=global_orient_w.device).repeat(R_w.size(0), 1, 1)  # (F, 4, 4)
    T_w2c[..., :3, :3] = R_w2c  # (F, 3, 3)
    T_w2c[..., :3, 3] = t_w2c  # (F, 3)
    return T_w2c
274
+
275
+
276
def get_R_c2gv(R_w2c, axis_gravity_in_w=[0, 0, -1]):
    """Compute the camera-to-gravity-view rotation.

    Builds a "gravity view" frame whose y-axis is the gravity direction
    expressed in camera coordinates.

    Args:
        R_w2c: (*, 3, 3) world-to-camera rotations.
        axis_gravity_in_w: gravity direction in world coordinates.
    Returns:
        R_c2gv: (*, 3, 3) camera-to-gravity-view rotations.
    """
    if isinstance(axis_gravity_in_w, list):
        axis_gravity_in_w = torch.tensor(axis_gravity_in_w).float()  # gravity direction in world coord
    cam_z = torch.tensor([0, 0, 1]).float()

    # gv y-axis: gravity expressed in camera coordinates
    gv_y = R_w2c @ axis_gravity_in_w  # (*, 3)
    gv_x = torch.cross(gv_y, cam_z.expand_as(gv_y), dim=-1)
    gv_x_norm = gv_x.norm(dim=-1, keepdim=True)
    gv_x = gv_x / (gv_x_norm + 1e-5)
    # degenerate case (gravity ~ parallel to camera z): fall back to cam x-axis
    gv_x[gv_x_norm.squeeze(-1) < 1e-5] = torch.tensor([1.0, 0.0, 0.0])
    gv_z = torch.cross(gv_x, gv_y, dim=-1)

    R_gv2c = torch.stack([gv_x, gv_y, gv_z], dim=-1)  # (*, 3, 3), columns are the gv axes
    return R_gv2c.transpose(-1, -2)  # (*, 3, 3)
def get_c_rootparam(global_orient: Tensor, transl: Tensor, T_w2c: Tensor, offset: Tensor) -> Tuple[Tensor, Tensor]:
    """Transform SMPL root parameters from world to camera coordinates.

    Args:
        global_orient: (F, 3) world-frame root orientation (axis-angle).
        transl: (F, 3) world-frame root translation.
        T_w2c: (*, 4, 4) world-to-camera transforms (single or per-frame).
        offset: (3,) body-model root offset.
    Returns:
        R_c: (F, 3) camera-frame orientation (axis-angle).
        t_c: (F, 3) camera-frame translation.
    """
    assert global_orient.shape == transl.shape and global_orient.ndim == 2

    rot_world = axis_angle_to_matrix(global_orient)  # (F, 3, 3)

    rot_w2c = T_w2c[..., :3, :3]  # (*, 3, 3)
    trans_w2c = T_w2c[..., :3, 3]  # (*, 3)
    if rot_w2c.ndim == 2:
        # single shared transform -> broadcast across frames
        rot_w2c = rot_w2c[None].expand(rot_world.size(0), -1, -1)  # (F, 3, 3)
        trans_w2c = trans_w2c[None].expand(transl.size(0), -1)

    R_c = matrix_to_axis_angle(rot_w2c @ rot_world)  # (F, 3)
    t_c = torch.einsum("fij,fj->fi", rot_w2c, transl + offset) + trans_w2c - offset  # (F, 3)
    return R_c, t_c
def compute_cam_angvel(R_w2c, padding_last=True):
    """Camera angular velocity as 6D rotations between consecutive frames.

    Args:
        R_w2c: (F, 3, 3) per-frame world-to-camera rotations.
        padding_last: must be True; the last row is repeated to keep length F.
    Returns:
        (F, 6) float tensor of frame-to-frame relative rotations.
    """
    assert padding_last
    # relative rotation between consecutive frames: R @ R0 = R1 => R = R1 @ R0^T
    rel_rot = R_w2c[1:] @ R_w2c[:-1].transpose(-1, -2)
    angvel = matrix_to_rotation_6d(rel_rot)  # (F-1, 6)
    angvel = torch.cat([angvel, angvel[-1:]], dim=0)  # (F, 6) pad last frame
    return angvel.float()
def rot6d_to_rotation_matrix(rot6d):
    """Convert the 6D rotation representation to 3x3 rotation matrices.

    Based on Zhou et al., "On the Continuity of Rotation Representations in
    Neural Networks", CVPR 2019.

    Args:
        rot6d: torch tensor of shape (..., 6).
    Returns:
        torch tensor of shape (..., 3, 3) of rotation matrices.
    """
    cols = rot6d.view(*rot6d.shape[:-1], 3, 2)
    raw_x = cols[..., 0]
    raw_y = cols[..., 1]
    # Gram-Schmidt: orthonormalize the two columns, then complete the
    # right-handed frame with a cross product.
    basis_x = F.normalize(raw_x, dim=-1)
    proj = torch.einsum("...i,...i->...", basis_x, raw_y).unsqueeze(-1)
    basis_y = F.normalize(raw_y - proj * basis_x, dim=-1)
    basis_z = torch.cross(basis_x, basis_y, dim=-1)
    return torch.stack((basis_x, basis_y, basis_z), dim=-1)
def rotation_matrix_to_rot6d(rotation_matrix):
    """Convert 3x3 rotation matrices to the 6D rotation representation.

    Args:
        rotation_matrix: torch tensor of shape (..., 3, 3).
    Returns:
        torch tensor of shape (..., 6) holding the first two columns,
        flattened row-major.
    """
    first_two_cols = rotation_matrix[..., 0:2]  # (..., 3, 2)
    return first_two_cols.reshape(*rotation_matrix.shape[:-2], 6)
def quaternion_to_rotation_matrix(quaternion):
    """Convert (w, x, y, z) quaternions to rotation matrices.

    The quaternion is normalized internally, so unit length is not required.

    Args:
        quaternion: torch tensor of shape (..., 4) in (w, x, y, z) order.
    Returns:
        torch tensor of shape (..., 3, 3).
    """
    unit_q = quaternion / quaternion.norm(p=2, dim=-1, keepdim=True)
    w, x, y, z = unit_q[..., 0], unit_q[..., 1], unit_q[..., 2], unit_q[..., 3]

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    entries = [
        w2 + x2 - y2 - z2,
        2 * xy - 2 * wz,
        2 * wy + 2 * xz,
        2 * wz + 2 * xy,
        w2 - x2 + y2 - z2,
        2 * yz - 2 * wx,
        2 * xz - 2 * wy,
        2 * wx + 2 * yz,
        w2 - x2 - y2 + z2,
    ]
    return torch.stack(entries, dim=-1).view(*quaternion.shape[:-1], 3, 3)
def quaternion_to_angle_axis(quaternion: Tensor) -> Tensor:
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert quaternion vector to angle axis of rotation.

    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    Args:
        quaternion (Tensor): tensor with quaternions in (w, x, y, z) order.

    Return:
        Tensor: tensor with angle axis of rotation.

    Raises:
        TypeError: if the input is not a torch.Tensor.
        ValueError: if the last dimension is not 4.

    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`

    Example:
        >>> quaternion = torch.rand(2, 4) # Nx4
        >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))

    if not quaternion.shape[-1] == 4:
        raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape))
    # unpack input and compute conversion
    q1: torch.Tensor = quaternion[..., 1]
    q2: torch.Tensor = quaternion[..., 2]
    q3: torch.Tensor = quaternion[..., 3]
    # |vector part|^2 = sin^2(theta/2) for a unit quaternion
    sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3

    sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
    cos_theta: torch.Tensor = quaternion[..., 0]
    # negating both atan2 arguments when w < 0 selects the equivalent
    # rotation with angle in [0, pi] (q and -q encode the same rotation)
    two_theta: torch.Tensor = 2.0 * torch.where(
        cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta)
    )

    # k scales the quaternion vector part into an axis-angle vector; for a
    # near-zero vector part fall back to the first-order approximation k = 2
    # to avoid dividing by sin_theta ~ 0
    k_pos: torch.Tensor = two_theta / sin_theta
    k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
    k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)

    angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
    angle_axis[..., 0] += q1 * k
    angle_axis[..., 1] += q2 * k
    angle_axis[..., 2] += q3 * k
    return angle_axis
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert 3x4 rotation matrix to 4d quaternion vector

    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Args:
        rotation_matrix (Tensor): the rotation matrix to convert; (N, 3, 3)
            inputs are padded to (N, 3, 4) with a homogeneous column.
        eps (float): threshold used when branching on the matrix trace terms.

    Return:
        Tensor: the rotation in quaternion, (w, x, y, z) order.

    Raises:
        TypeError: if the input is not a torch.Tensor.
        ValueError: if the input has more than three dimensions.

    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 4)`

    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))

    if len(rotation_matrix.shape) > 3:
        raise ValueError("Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape))
    if not rotation_matrix.shape[-2:] == (3, 4):
        # pad (N, 3, 3) input with a [0, 0, 1] column so the indexing below works
        hom = (
            torch.tensor([0, 0, 1], dtype=rotation_matrix.dtype, device=rotation_matrix.device)
            .reshape(1, 3, 1)
            .expand(rotation_matrix.shape[0], -1, -1)
        )
        rotation_matrix = torch.cat([rotation_matrix, hom], dim=-1)

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    # Branch selection: the four candidate quaternions below each divide by a
    # different trace-derived term; the masks pick the numerically largest one.
    mask_d2 = rmat_t[:, 2, 2] < eps

    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    # candidate dominated by the x component
    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack(
        [rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2]],
        -1,
    )
    t0_rep = t0.repeat(4, 1).t()

    # candidate dominated by the y component
    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack(
        [rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]],
        -1,
    )
    t1_rep = t1.repeat(4, 1).t()

    # candidate dominated by the z component
    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack(
        [rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2],
        -1,
    )
    t2_rep = t2.repeat(4, 1).t()

    # candidate dominated by the scalar (w) component
    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack(
        [t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0]],
        -1,
    )
    t3_rep = t3.repeat(4, 1).t()

    # exactly one of the four masks is true per row
    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * ~mask_d0_d1
    mask_c2 = ~mask_d2 * mask_d0_nd1
    mask_c3 = ~mask_d2 * ~mask_d0_nd1
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    # blend the candidates (masks are one-hot) and normalize
    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 + t3_rep * mask_c3)  # noqa
    q *= 0.5
    return q
def rotation_matrix_to_angle_axis(rotation_matrix):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert rotation matrices to axis-angle (Rodrigues) vectors.

    Accepts (..., 3, 3) or (..., 3, 4) input; 3x3 matrices are padded with a
    homogeneous column before conversion.

    Args:
        rotation_matrix (Tensor): rotation matrix.
    Returns:
        Tensor: Rodrigues vectors of shape (..., 3).
    """
    batch_shape = rotation_matrix.shape[:-2]
    flat = rotation_matrix.reshape(-1, *rotation_matrix.shape[-2:])
    if flat.shape[1:] == (3, 3):
        # pad with a [0, 0, 1] column so the quaternion converter accepts it
        pad = (
            torch.tensor([0, 0, 1], dtype=rotation_matrix.dtype, device=rotation_matrix.device)
            .reshape(1, 3, 1)
            .expand(flat.shape[0], -1, -1)
        )
        flat = torch.cat([flat, pad], dim=-1)

    aa = quaternion_to_angle_axis(rotation_matrix_to_quaternion(flat))
    aa[torch.isnan(aa)] = 0.0  # guard the degenerate sin(theta) = 0 division
    return aa.reshape(*batch_shape, 3)
def quat_to_rotmat(quat):
    """Convert (w, x, y, z) quaternions to rotation matrices.

    Args:
        quat: tensor of shape [B, 4]; normalized internally.
    Returns:
        Rotation matrices of shape [B, 3, 3].
    """
    unit_q = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = unit_q[:, 0], unit_q[:, 1], unit_q[:, 2], unit_q[:, 3]

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    entries = [
        w2 + x2 - y2 - z2,
        2 * xy - 2 * wz,
        2 * wy + 2 * xz,
        2 * wz + 2 * xy,
        w2 - x2 + y2 - z2,
        2 * yz - 2 * wx,
        2 * xz - 2 * wy,
        2 * wx + 2 * yz,
        w2 - x2 - y2 + z2,
    ]
    return torch.stack(entries, dim=1).view(quat.size(0), 3, 3)
def angle_axis_to_rotation_matrix(theta):
    """Convert axis-angle vectors to rotation matrices via quaternions.

    Args:
        theta: tensor of shape [..., 3].
    Returns:
        Rotation matrices of shape [..., 3, 3].
    """
    lead_shape = theta.shape[:-1]
    flat = theta.reshape(-1, 3)
    # tiny epsilon keeps the norm strictly positive for the division below
    angle = torch.norm(flat + 1e-8, p=2, dim=1).unsqueeze(-1)
    axis = flat / angle
    half_angle = angle * 0.5
    quat = torch.cat([torch.cos(half_angle), torch.sin(half_angle) * axis], dim=1)
    return quat_to_rotmat(quat).reshape(*lead_shape, 3, 3)
def rotation_matrix_to_euler_angles(rotation_matrix):
    """Convert 3x3 rotation matrices to "xyz" Euler angles in degrees.

    Accepts torch tensors or numpy arrays; torch input returns a torch tensor
    on the original device. Note the output batch is flattened to (N, 3).
    """
    from scipy.spatial.transform import Rotation

    device = None
    if isinstance(rotation_matrix, Tensor):
        device = rotation_matrix.device
        rotation_matrix = rotation_matrix.cpu().numpy()

    flat = rotation_matrix.reshape(-1, 3, 3)
    euler = Rotation.from_matrix(flat).as_euler("xyz", degrees=True)
    if device is not None:
        return torch.from_numpy(euler).to(device)
    return euler
def euler_angles_to_rotation_matrix(euler_angles, degrees=True):
    """Convert Euler angles ("xyz" order) to 3x3 rotation matrices.

    Args:
        euler_angles: array of shape (..., 3).
        degrees: interpret angles as degrees when True, radians otherwise.

    Returns:
        numpy array of rotation matrices with shape (..., 3, 3).
    """
    from scipy.spatial.transform import Rotation

    lead_shape = euler_angles.shape[:-1]
    flat = euler_angles.reshape(-1, 3)
    mats = Rotation.from_euler("xyz", flat, degrees=degrees).as_matrix()
    return mats.reshape(*lead_shape, 3, 3)
def get_local_transl_vel(transl, global_orient_R, fps=30):
    """Compute root translational velocity in the local (SMPL) frame.

    Args:
        transl: (*, L, 3) global root translation per frame.
        global_orient_R: (*, L, 3, 3) global root rotation matrices.
        fps: frame rate used to scale frame differences into units/second.
    Returns:
        transl_vel: (*, L, 3) local-frame velocity; frame 0 is zero-padded.
    """
    transl_vel = transl[..., 1:, :] - transl[..., :-1, :]  # (*, L-1, 3)
    # Zero-pad the first frame. BUGFIX: the previous code used
    # `transl_vel[:1]`, which indexes the *batch* dimension and breaks for
    # batched (B, L, 3) input; index the time dimension explicitly.
    zero_pad = torch.zeros_like(transl_vel[..., :1, :])
    transl_vel = torch.cat([zero_pad, transl_vel], dim=-2) * fps  # (*, L, 3)

    # v_local = R^T @ v_global
    local_transl_vel = torch.einsum("...lij,...li->...lj", global_orient_R, transl_vel)
    return local_transl_vel
def compute_transl_full_cam(pred_cam, bbx_xys, K_fullimg):
    """Convert weak-perspective camera params into a full-image translation.

    Args:
        pred_cam: (..., 3) predicted (scale, tx, ty) in the bbox crop.
        bbx_xys: (..., 3) bbox center x, center y and size in pixels.
        K_fullimg: (..., 3, 3) full-image camera intrinsics.
    Returns:
        (..., 3) camera-frame translation (tx, ty, tz).
    """
    scale, tx, ty = pred_cam[..., 0], pred_cam[..., 1], pred_cam[..., 2]
    focal = K_fullimg[..., 0, 0]
    cx_img = K_fullimg[..., 0, 2]
    cy_img = K_fullimg[..., 1, 2]

    scaled_size = scale * bbx_xys[..., 2]
    denom = scaled_size + 1e-9  # guard against zero bbox size / scale
    dx = 2 * (bbx_xys[..., 0] - cx_img) / denom
    dy = 2 * (bbx_xys[..., 1] - cy_img) / denom
    tz = 2 * focal / denom

    return torch.stack([tx + dx, ty + dy, tz], dim=-1)
def quaternion_fix_continuity(q: Tensor) -> Tensor:
    """Force quaternion continuity across the time (first) dimension.

    q and -q encode the same rotation; for each frame this picks the sign
    with maximal dot product against the previous frame.

    Args:
        q: (L, 4) or (L, J, 4) quaternion sequence.
    Returns:
        A sign-fixed copy of q.
    """
    assert q.ndim in (
        2,
        3,
    ), f"Expected 3D tensor (L, J, 4), or 2D tensor (L, 4), but got shape {q.shape}"
    assert q.shape[-1] == 4, f"Last dimension should be 4 for quaternions, got {q.shape[-1]}"
    fixed = q.clone()
    if q.shape[0] <= 1:
        return fixed  # single frame or empty: nothing to align

    # dot products between consecutive frames; negative means a hemisphere flip
    consec_dot = (q[1:] * q[:-1]).sum(dim=-1)
    # parity of accumulated flips decides the final sign of each later frame
    needs_flip = (torch.cumsum((consec_dot < 0).int(), dim=0) % 2).bool()
    fixed[1:][needs_flip] *= -1
    return fixed
def rot_mat2trans_mat(rot_mat: np.ndarray) -> np.ndarray:
    """Embed rotation matrices into 4x4 homogeneous transforms.

    Args:
        rot_mat: (3, 3), (N, 3, 3) or (N, M, 3, 3) rotation matrices.
    Returns:
        Homogeneous transforms of matching batch shape, zero translation.
    Raises:
        NotImplementedError: for any other number of dimensions.
    """
    if rot_mat.ndim == 2:
        trans_mat = np.identity(4)
    elif rot_mat.ndim in (3, 4):
        trans_mat = np.tile(np.identity(4), rot_mat.shape[:-2] + (1, 1))
    else:
        raise NotImplementedError
    trans_mat[..., :3, :3] = rot_mat
    return trans_mat
def trans2trans_mat(trans: np.ndarray) -> np.ndarray:
    """Embed translation vectors into 4x4 homogeneous transforms.

    Args:
        trans: (3,), (N, 3) or (N, M, 3) translation vectors.
    Returns:
        Homogeneous transforms with identity rotation and matching batch shape.
    """
    assert trans.shape[-1] == 3
    assert trans.ndim in (1, 2, 3), trans.shape
    if trans.ndim == 1:
        trans_mat = np.identity(4)
    else:
        trans_mat = np.tile(np.identity(4), trans.shape[:-1] + (1, 1))
    trans_mat[..., :3, 3] = trans
    return trans_mat
def gaussian_kernel1d(sigma: float, order: int, radius: int) -> np.ndarray:
    """Computes a 1D Gaussian convolution kernel.

    (from scipy)

    Args:
        sigma: standard deviation of the Gaussian.
        order: derivative order; 0 gives the plain (normalized) Gaussian.
        radius: kernel half-width; the kernel has 2 * radius + 1 taps.

    Returns:
        1D numpy array of kernel coefficients.

    Raises:
        ValueError: if order is negative.
    """
    if order < 0:
        raise ValueError("order must be non-negative")
    exponent_range = np.arange(order + 1)
    sigma2 = sigma * sigma
    x = np.arange(-radius, radius + 1)
    # base Gaussian, normalized to unit sum
    phi_x = np.exp(-0.5 / sigma2 * x**2)
    phi_x = phi_x / phi_x.sum()

    if order == 0:
        return phi_x
    else:
        # f(x) = q(x) * phi(x) = q(x) * exp(p(x))
        # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
        # p'(x) = -1 / sigma ** 2
        # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
        # coefficients of q(x)
        q = np.zeros(order + 1)
        q[0] = 1
        D = np.diag(exponent_range[1:], 1)  # D @ q(x) = q'(x)
        P = np.diag(np.ones(order) / -sigma2, -1)  # P @ q(x) = q(x) * p'(x)
        Q_deriv = D + P
        # apply the derivative operator `order` times to the polynomial coeffs
        for _ in range(order):
            q = Q_deriv.dot(q)
        # evaluate the resulting polynomial at each x and modulate the Gaussian
        q = (x[:, None] ** exponent_range).dot(q)
        return q * phi_x
def slice_seq_with_padding(whole_seq: np.ndarray, middle_idx: int, length: int) -> np.ndarray:
    """Extract a window of `length` frames centered at `middle_idx`.

    Out-of-range portions are filled by repeating the first/last frame, so
    the returned window always has exactly `length` frames.
    """
    half = length // 2
    left_pad = max(0, half - middle_idx)
    right_pad = max(0, middle_idx + length - half - len(whole_seq))

    padded = whole_seq.copy()
    if left_pad > 0:
        padded = np.concatenate([np.stack([padded[0]] * left_pad), padded], axis=0)
    if right_pad > 0:
        padded = np.concatenate([padded, np.stack([padded[-1]] * right_pad)], axis=0)
    assert len(padded) == len(whole_seq) + left_pad + right_pad

    start = middle_idx + left_pad - half
    assert start >= 0
    assert start + length <= len(padded)
    return padded[start : start + length]
def wavg_quaternion_markley(Q: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Weighted average of quaternions (Markley et al.).

    Python port of Tolga Birdal's algorithm, via https://stackoverflow.com/a/49690919

    Markley, F. Landis, Yang Cheng, John Lucas Crassidis, and Yaakov Oshman.
    "Averaging quaternions." Journal of Guidance, Control, and Dynamics 30,
    no. 4 (2007): 1193-1197.

    Arguments:
        Q(ndarray): an Mx4 ndarray of quaternions.
        weights(list): an M elements list, a weight for each quaternion
            (e.g. votes or confidences guiding trust toward specific poses).

    Returns:
        The weighted average quaternion (4,), the eigenvector of the
        accumulator matrix with the largest eigenvalue.
    """
    accum = np.zeros((4, 4))
    weight_sum = 0
    num_quats = Q.shape[0]

    for i in range(num_quats):
        quat = Q[i, :]
        weight = weights[i]
        if quat[0] < 0:
            quat = -quat  # antipodal fix: q and -q encode the same rotation
        accum += weight * np.outer(quat, quat)  # rank-1 update
        weight_sum += weight

    accum /= weight_sum

    # eigh sorts eigenvalues ascending; take the last eigenvector
    return np.linalg.eigh(accum)[1][:, -1]
hymotion/utils/loaders.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import json
3
+ import os
4
+
5
+
6
def load_object(module_name, module_args, **extra_args):
    """Instantiate `pkg.mod.ClassName` with the given keyword arguments.

    Args:
        module_name: dotted (or slash-separated) path ending in the object
            name, e.g. "hymotion/utils/x.Net" or "torch.nn.MSELoss".
        module_args: dict of constructor kwargs, or None for no arguments.
        **extra_args: extra kwargs merged over `module_args`.
    Returns:
        The constructed object.
    """
    # BUGFIX: the None check must come before .copy(); previously
    # module_args.copy() crashed on None and the None branch was dead code.
    module_args = {} if module_args is None else module_args.copy()
    module_args.update(extra_args)

    module_path = ".".join(module_name.split(".")[:-1]).replace("/", ".")
    attr_name = module_name.split(".")[-1]
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)(**module_args)
def load_module(module_name):
    """Resolve `pkg.mod.attr` (or slash-separated path) to the attribute.

    FIX: everything before the last dot is now treated as the module path
    (consistent with `load_object`), so multi-level dotted paths such as
    "os.path.join" resolve correctly; slash-separated paths behave as before.

    Args:
        module_name: e.g. "hymotion/utils/loaders.read_json" or "os.path.join".
    Returns:
        The resolved attribute (class, function, ...).
    """
    module_path = ".".join(module_name.split(".")[:-1]).replace("/", ".")
    module = importlib.import_module(module_path)
    name = module_name.split(".")[-1]
    return getattr(module, name)
def check_cfg(cfg, global_dict, verbose=True):
    """Recursively resolve "$key" string values in cfg from global_dict.

    Mutates cfg in place: any string value "$foo" is replaced by
    global_dict["foo"]; nested dicts are processed recursively.
    """
    for key, val in cfg.items():
        if isinstance(val, dict):
            check_cfg(val, global_dict, verbose)
        elif isinstance(val, str) and val.startswith("$"):
            if verbose:
                print(f" - Update {key} with {val} = {global_dict[val[1:]]}")
            cfg[key] = global_dict[val[1:]]
def read_yaml(yamlname):
    """Load a YAML file, falling back to FullLoader for custom tags.

    Objects exposing to_dict()/_cfg_dict (e.g. mmcv-style Config) are
    converted to plain dicts before returning.
    """
    import yaml

    with open(yamlname, "r", encoding="utf-8") as fin:
        try:
            data = yaml.safe_load(fin)
        except yaml.constructor.ConstructorError:
            # file uses python-specific tags; retry with the full loader
            fin.seek(0)
            data = yaml.load(fin, Loader=yaml.FullLoader)

    if hasattr(data, "to_dict"):
        data = data.to_dict()
    elif hasattr(data, "_cfg_dict"):
        data = dict(data._cfg_dict)
    return data
def write_yaml(data, yamlname):
    """Serialize `data` to a YAML file (UTF-8)."""
    import yaml

    with open(yamlname, "w", encoding="utf-8") as fout:
        yaml.dump(data, fout)
def check_input(data, verbose=True):
    """Pop data["input"] (a list of yaml paths) and merge their contents.

    Returns a dict with the parent configs merged in order (later files
    override earlier ones), or an empty dict if no "input" key is present.
    Note: `data` is mutated — the "input" key is removed.
    """
    merged = {}
    if "input" in data:
        if verbose:
            print(" - Check input file list")
        for parent_path in data.pop("input"):
            merged.update(read_yaml(parent_path))
    return merged
def merge_dict(dict_A, dict_B, key, verbose=True):
    """Recursively merge dict_B[key] into dict_A[key].

    For dict values: keys present in both sides are merged recursively and
    removed from dict_B[key]; remaining dict_B[key] entries are then created
    in dict_A[key]. For non-dict values, dict_B[key] simply overwrites.

    NOTE(review): `dict_B.copy()` is a *shallow* copy, so the nested
    `dict_B[key].pop(key2)` still mutates the caller's nested dict —
    presumably relied upon by read_config (which pops merged top-level keys
    afterwards); verify before reusing this helper elsewhere.
    """
    if isinstance(dict_A[key], dict):
        dict_B = dict_B.copy()
        for key2, val2 in dict_A[key].items():
            if key2 in dict_B[key]:
                # both sides define key2 -> recurse, then drop it from B
                merge_dict(dict_A[key], dict_B[key], key2, verbose)
                dict_B[key].pop(key2)
        if len(dict_B[key]) > 0:
            # keys only present in B are created fresh in A
            if verbose:
                print(f" - Create {key} with {dict_B[key]}")
            for key2, val2 in dict_B[key].items():
                dict_A[key][key2] = val2
    else:
        # leaf value: B overrides A
        if verbose:
            print(f" - Update {key} with {dict_B[key]}")
        dict_A[key] = dict_B[key]
def read_config(cfgname, verbose=True):
    """Read a yaml config, resolving "input" parents and "$var" references.

    Parent configs listed under "input" are loaded first; the child config
    is then merged on top of them, and "$key" placeholders are substituted
    from the resulting dict.
    """
    data_base = read_yaml(cfgname)
    data_parent = check_input(data_base, verbose)

    # overlay the child config on top of its parents
    shared_keys = [k for k in data_parent if k in data_base]
    for key in shared_keys:
        merge_dict(data_parent, data_base, key, verbose)
        if verbose:
            print(data_parent[key])
        data_base.pop(key)
    data_parent.update(data_base)

    cfg = data_parent
    check_cfg(cfg, cfg, verbose)
    return cfg
def update_config(config, args):
    """Override config entries with non-None argparse values (in place).

    Only keys already present in `config` are touched; None values in
    `args` are ignored so unset CLI options do not clobber the config.
    """
    for name, value in vars(args).items():
        if value is not None and name in config:
            config[name] = value
def read_yaml_full(path):
    """Load a YAML file with FullLoader (allows python-specific tags)."""
    import yaml

    with open(path, "r") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
def check_ceph_path(path):
    """Return `path` unchanged if it exists on disk.

    Args:
        path: filesystem path to validate.
    Returns:
        The same path.
    Raises:
        ValueError: if the path does not exist.
    """
    # `os` is imported at module level; the redundant function-local
    # `import os` was removed.
    if os.path.exists(path):
        return path
    raise ValueError(f"{path} not found")
def read_json(filename):
    """Load and return the contents of a JSON file (UTF-8)."""
    with open(filename, "r", encoding="utf-8") as fin:
        return json.load(fin)
def write_json(data, filename):
    """Write `data` as pretty-printed UTF-8 JSON (non-ASCII kept verbatim)."""
    with open(filename, "w", encoding="utf-8") as fout:
        json.dump(data, fout, ensure_ascii=False, indent=4)
def load_h5_dataset(filename, ds_name_list=None, parser=None):
    """Load an HDF5 file into a nested dict of arrays and attributes.

    Args:
        filename: path to the .h5 file; may carry a "@start:end" suffix to
            frame-slice the "LclRotation"/"LclTranslation" datasets.
        ds_name_list: optional whitelist of dataset/group names to load.
        parser: optional dict mapping dataset name -> callable applied to
            the loaded array.
    Returns:
        dict of datasets (nested dicts for h5 groups) plus attributes.
    """
    import h5py

    # ds for dataset
    if "@" in filename:
        # e.g. "motion.h5@10:50" loads only frames [10, 50)
        filename, start_end = filename.split("@")
        start = int(start_end.split(":")[0])
        end = int(start_end.split(":")[1])
    else:
        start = None
        end = None
    assert os.path.isfile(filename), "cannot find: {}".format(filename)

    def load_dict(d):
        # recursively convert an h5 group (or the file root) to a plain dict
        ds_dict = {}
        for item in d.keys():
            if ds_name_list is not None and item not in ds_name_list:
                continue
            if isinstance(d[item], h5py._hl.dataset.Dataset):
                ds_dict[item] = d[item][()]  # read the full dataset into memory
                if parser is not None and item in parser:
                    ds_dict[item] = parser[item](ds_dict[item])
            elif isinstance(d[item], h5py._hl.group.Group):
                ds_dict[item] = load_dict(d[item])
        for item in d.attrs.keys():
            ds_dict[item] = d.attrs[item]
        return ds_dict

    with h5py.File(filename, "r") as f:
        ds_dict = load_dict(f)
        for item in f.attrs.keys():
            # NOTE(review): root attrs were already copied inside load_dict;
            # this second pass is redundant but harmless
            ds_dict[item] = f.attrs[item]
        if start is not None and end is not None:
            for key in ["LclRotation", "LclTranslation"]:
                ds_dict[key] = ds_dict[key][start:end]
    return ds_dict
+
176
if __name__ == "__main__":
    # Smoke test: both dotted and slash-separated module paths should
    # resolve to the same object.
    # hymotion.utils.loaders
    network = load_object("hymotion.utils.base_example.ToyNetwork", {})
    print(network)
    network = load_object("hymotion/utils/base_example.ToyNetwork", {})
    print(network)
    # third-party targets also work through the same mechanism
    load_object("diffusers.DDIMScheduler", {})
    module = load_object("torch.nn.MSELoss", {"reduction": "none"})
    print(module)
hymotion/utils/misc.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from collections.abc import Iterable, Sequence
3
+ from importlib import import_module
4
+ from itertools import repeat
5
+ from os import path as osp
6
+ from typing import Any, Callable, Optional, Tuple, Union
7
+
8
+
9
def is_str(x: Any) -> bool:
    """Return True if x is a str instance.

    Note: This method is deprecated since python 2 is no longer supported;
    prefer isinstance(x, str) directly.
    """
    return isinstance(x, str)
def is_seq_of(seq: Any, expected_type: Any, seq_type: Any = None) -> bool:
    """Check whether `seq` is a sequence whose items all match a type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type; defaults to
            collections.abc.Sequence.
    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is None:
        container_type = Sequence
    else:
        assert isinstance(seq_type, type)
        container_type = seq_type
    if not isinstance(seq, container_type):
        return False
    return all(isinstance(item, expected_type) for item in seq)
def is_list_of(seq: Any, expected_type: Any) -> bool:
    """Check whether it is a list of some type.

    A partial method of :func:`is_seq_of` with seq_type fixed to list.
    """
    return is_seq_of(seq, expected_type, seq_type=list)
def is_tuple_of(seq: Any, expected_type: Any) -> bool:
    """Check whether it is a tuple of some type.

    A partial method of :func:`is_seq_of` with seq_type fixed to tuple.
    """
    return is_seq_of(seq, expected_type, seq_type=tuple)
def import_modules_from_strings(
    imports: Union[list[str], str], allow_failed_imports: bool = False
) -> Optional[list[Any]]:
    """Import modules named by string(s).

    Args:
        imports: a module name or a list of module names.
        allow_failed_imports: if True, a failed import emits a UserWarning
            and yields None instead of raising.
    Returns:
        The imported module(s): a single module if `imports` was a string,
        a list otherwise, or None if `imports` is falsy.
    Raises:
        TypeError: if `imports` is not a str or list of str.
        ImportError: if a module fails to import and failures are not allowed.
    """
    if not imports:
        return None
    single_import = isinstance(imports, str)
    if single_import:
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
        try:
            module = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # BUGFIX: re-raise the original exception instead of a bare
                # `raise ImportError`, which discarded the useful message
                # (which module failed, and why) and the traceback.
                raise
            warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
            module = None
        imported.append(module)
    if single_import:
        return imported[0]
    return imported
+ def _ntuple(n: int) -> Callable:
86
+ def parse(x: Any) -> Tuple:
87
+ if isinstance(x, Iterable) and not isinstance(x, str):
88
+ x = tuple(x)
89
+ if len(x) == 1:
90
+ x = tuple(repeat(x[0], n))
91
+ return x
92
+ return tuple(repeat(x, n))
93
+
94
+ return parse
95
+
96
+
97
# Common arities for size/kernel/stride arguments (timm-style helpers).
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
def seconds_to_hmsms(seconds: float) -> tuple[int, int, int, int]:
    """Split a duration in seconds into (hours, minutes, seconds, ms)."""
    hours, rem = divmod(seconds, 3600)
    minutes, rem = divmod(rem, 60)
    whole_seconds, frac = divmod(rem, 1)
    return int(hours), int(minutes), int(whole_seconds), int(frac * 1000)
def frames_to_hmsms(frames: int, frame_rate: int = 30) -> tuple[int, int, int, int]:
    """Convert a frame count at `frame_rate` fps into (h, m, s, ms)."""
    return seconds_to_hmsms(frames / frame_rate)
hymotion/utils/motion_process.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor
6
+
7
+
8
def smooth_quats(quats: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """Gaussian-smooth a (T, 4) quaternion sequence with Markley weighted averaging.

    Each output frame is the weighted average of a window of neighbouring
    quaternions, after flipping signs so every window member lies in the same
    hemisphere as the window centre. Returns a new array; the input is not
    modified. ``sigma <= 0`` or an empty sequence returns an unmodified copy.
    """
    from .geometry import gaussian_kernel1d, quaternion_fix_continuity, slice_seq_with_padding, wavg_quaternion_markley

    if len(quats) == 0 or sigma <= 0:
        return quats.copy()

    # Resolve q/-q sign ambiguity across time before windowed averaging.
    fixed = quaternion_fix_continuity(torch.from_numpy(quats)).numpy()
    smoothed = fixed.copy()

    # Kernel radius matches scipy's default truncation of 4 sigma.
    radius = int(4.0 * float(sigma) + 0.5)
    kernel = gaussian_kernel1d(sigma=sigma, order=0, radius=radius)[::-1]
    ksize = len(kernel)
    center = ksize // 2

    for frame in range(len(fixed)):
        window = slice_seq_with_padding(fixed, frame, ksize)  # (K, 4)
        pivot = window[center : center + 1]  # (1, 4)
        # Flip any window member pointing into the opposite hemisphere.
        alignment = (window * pivot).sum(axis=-1, keepdims=True)  # (K, 1)
        window = np.where(alignment < 0.0, -window, window)
        smoothed[frame, :] = wavg_quaternion_markley(window, kernel)

    return smoothed.copy()
32
+
33
+
34
def smooth_rotation(
    quats: np.ndarray,
    sigma: float = 1.0,
) -> np.ndarray:
    """Smooth per-joint quaternion tracks over time.

    Args:
        quats: Quaternions of shape (B, T, J, 4) or (T, J, 4).
        sigma: Gaussian smoothing strength; passed through to
            :func:`smooth_quats`.

    Returns:
        A smoothed array with the same shape as the input.

    Note:
        Bugfix vs. the previous version: the input array is copied before
        smoothing. Previously the per-joint assignments wrote through to the
        caller's array (even for 3-D input, since ``quats[None, ...]`` is a
        view), silently mutating the argument.
    """
    from .geometry import quaternion_fix_continuity

    is_batch = quats.ndim == 4
    # Work on a copy so the caller's array is never mutated in place.
    out = quats.copy() if is_batch else quats[None, ...].copy()
    for b in range(out.shape[0]):
        for j_idx in range(out.shape[2]):
            cur = quaternion_fix_continuity(torch.from_numpy(out[b, :, j_idx].copy())).numpy()
            out[b, :, j_idx] = smooth_quats(cur, sigma=sigma)
    return out if is_batch else out.squeeze(0)
55
+
56
+
57
def unwrap_euler_over_time(xyz: torch.Tensor) -> torch.Tensor:
    """Unwrap Euler-angle trajectories so they evolve continuously over time.

    Args:
        xyz: Angles (radians) of shape (B, L, J, 3); the time axis is dim 1.

    Returns:
        A tensor of the same shape where y[t] = y[0] + cumsum(wrap(dy)),
        i.e. each per-frame increment is wrapped into (-pi, pi] before
        accumulation, removing 2*pi jumps.
    """
    out = xyz.clone()
    deltas = out[:, 1:] - out[:, :-1]
    # atan2(sin, cos) maps each increment back into the principal interval.
    wrapped = torch.atan2(torch.sin(deltas), torch.cos(deltas))
    out[:, 1:] = out[:, :1] + torch.cumsum(wrapped, dim=1)
    return out
hymotion/utils/path.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import platform
4
+ from pathlib import Path
5
+ from typing import Any, Generator, List, Optional, Union
6
+
7
+ from .misc import is_str
8
+
9
+ if platform.system() == "Windows":
10
+ import regex as re
11
+ else:
12
+ import re
13
+
14
+
15
def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None:
    """Raise FileNotFoundError unless ``filename`` is an existing regular file."""
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
18
+
19
+
20
def mkdir_or_exist(dir_name: str, mode: int = 0o777) -> None:
    """Create ``dir_name`` (and parents) if missing; a no-op for "" or existing dirs."""
    if dir_name == "":
        return
    expanded = osp.expanduser(dir_name)
    os.makedirs(expanded, mode=mode, exist_ok=True)
25
+
26
+
27
def symlink(src: str, dst: str, overwrite: bool = True, **kwargs) -> None:
    """Create a symlink ``dst`` pointing at ``src``, replacing ``dst`` if requested."""
    if overwrite and os.path.lexists(dst):
        # lexists also catches dangling symlinks that exists() would miss.
        os.remove(dst)
    os.symlink(src, dst, **kwargs)
31
+
32
+
33
def is_filepath(x: Any) -> bool:
    """Return True when ``x`` looks like a filesystem path (str or pathlib.Path)."""
    return isinstance(x, Path) or is_str(x)
35
+
36
+
37
def scandir(
    dir_path: Union[str, Path],
    suffix: Optional[str] = None,
    recursive: bool = False,
    case_sensitive: bool = True,
) -> Generator[str, None, None]:
    """Lazily yield files under a directory, filtered by suffix.

    Hidden files (name starting with ".") are skipped; directories — hidden or
    not — are descended into when ``recursive`` is True.

    Args:
        dir_path (str | :obj:`Path`): Directory to scan.
        suffix (str | tuple(str), optional): Only yield files whose relative
            path ends with this suffix (or one of these suffixes). Default: None.
        recursive (bool, optional): Descend into sub-directories. Default: False.
        case_sensitive (bool, optional): If False, suffix matching ignores
            case. Default: True.

    Returns:
        A generator of paths relative to ``dir_path``.
    """
    if not isinstance(dir_path, (str, Path)):
        raise TypeError('"dir_path" must be a string or Path object')
    dir_path = str(dir_path)

    if suffix is not None and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        if isinstance(suffix, str):
            suffix = suffix.lower()
        else:
            suffix = tuple(s.lower() for s in suffix)

    root = dir_path

    def _walk(current: str) -> Generator[str, None, None]:
        for entry in os.scandir(current):
            if entry.is_file() and not entry.name.startswith("."):
                rel_path = osp.relpath(entry.path, root)
                probe = rel_path if case_sensitive else rel_path.lower()
                if suffix is None or probe.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                yield from _walk(entry.path)

    return _walk(dir_path)
86
+
87
+
88
+ def find_files(directory, pattern, recursive=True, abspath=False) -> List[str]:
89
+ regex = re.compile(pattern)
90
+ file_list = []
91
+ for root, _, files in os.walk(directory):
92
+ for f in files:
93
+ if regex.match(f) is not None:
94
+ file_list.append(os.path.join(root, f))
95
+ if not recursive:
96
+ break
97
+ map_func = os.path.abspath if abspath else os.path.relpath
98
+ return list(map(map_func, sorted(file_list)))
99
+
100
+
101
def natural_keys(text: str, retoken: str = r"[a-zA-Z]*(\d+)[a-zA-Z_]*[\.].*", n: int = 1) -> Union[int, str]:
    """Extract a sort key for natural ("human") ordering of filenames.

    Splits ``text`` with ``retoken`` (whose capture group survives the split)
    and returns token ``n`` as an int when it is numeric, else lowercased.
    """
    token = re.split(retoken, text)[n]
    return int(token) if token.isdigit() else token.lower()
106
+
107
+
108
def listdirs(root):
    """Return every sub-directory path under ``root``, recursively.

    Replaces the previous ``lambda`` assignment (PEP 8 E731) with a named
    function; behavior is identical.
    """
    return [osp.join(base, d) for base, dirs, _ in os.walk(root) if dirs for d in dirs]


def listfiles(root):
    """Return the bare file names (no directories) of every file under ``root``."""
    return [f for base, _, files in os.walk(root) if files for f in files]
111
+
112
+
113
def parse_dirs_and_sort(
    input_dirs: Union[list, str],
    suffix: str,
    is_sort: bool = False,
    with_prefix: bool = True,
) -> List[str]:
    """Collect files matching ``suffix`` from one or more paths, optionally sorted.

    Args:
        input_dirs: A directory/file path or a list of them. Directories are
            scanned recursively (case-insensitive suffix match); file paths are
            kept only when they end with ``suffix``.
        suffix: File suffix to match.
        is_sort: Sort results with natural (numeric-aware) ordering, falling
            back to plain lexicographic sorting when the patterns don't apply.
        with_prefix: Prefix directory-scan results with their directory path.

    Returns:
        List of matching file paths.

    Raises:
        ValueError: If a given path does not exist, or ``input_dirs`` is
            neither a list nor a string.
    """

    def _expand(path: str) -> List[str]:
        # Expand one directory-or-file entry into its matching file paths.
        if osp.isdir(path):
            return [
                osp.join(path, x) if with_prefix else x
                for x in scandir(path, suffix=suffix, recursive=True, case_sensitive=False)
            ]
        if osp.isfile(path):
            return [path] if path.endswith(suffix) else []
        raise ValueError(f"Input path {path} does not exist.")

    if isinstance(input_dirs, list):
        input_dirs_list: List[str] = []
        for iter_input_dir in input_dirs:
            input_dirs_list += _expand(iter_input_dir)
    elif isinstance(input_dirs, str):
        input_dirs_list = _expand(input_dirs)
    else:
        raise ValueError("Only support list or str input.")

    if is_sort:
        # Try progressively simpler sort keys: numeric keys from two filename
        # patterns, then a single natural key, then plain string order.
        # (Previously used bare `except:`, which also swallowed KeyboardInterrupt.)
        try:
            try:
                input_dirs_list = sorted(
                    input_dirs_list,
                    key=lambda text: (
                        natural_keys(text, retoken=r"[a-zA-Z]*(\d+)_[0-9a-zA-Z_]*[\.].*", n=1),
                        natural_keys(text, retoken=r"[0-9a-zA-Z]*_(\d+)[a-zA-Z_]*[\.].*", n=1),
                    ),
                )
            except Exception:
                input_dirs_list = sorted(input_dirs_list, key=lambda text: (natural_keys(text)))
        except Exception:
            input_dirs_list = sorted(input_dirs_list, key=lambda text: text)

    return input_dirs_list
hymotion/utils/smplh2woodfbx.py ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ from typing import Dict, Optional
6
+
7
+ import fbx
8
+ import numpy as np
9
+ import torch
10
+ from transforms3d.euler import mat2euler
11
+
12
+ from .geometry import angle_axis_to_rotation_matrix, rot6d_to_rotation_matrix, rotation_matrix_to_angle_axis
13
+
14
# yapf: disable
# SMPL-H joint name -> index in the 52-joint pose vector. Indices 0-21 are the
# body skeleton; 22-36 are left-hand fingers and 37-51 the right hand.
SMPLH_JOINT2NUM = {
    "Pelvis": 0, "L_Hip": 1, "R_Hip": 2, "Spine1": 3,
    "L_Knee": 4, "R_Knee": 5, "Spine2": 6,
    "L_Ankle": 7, "R_Ankle": 8,
    "Spine3": 9,
    "L_Foot": 10, "R_Foot": 11,
    "Neck": 12, "L_Collar": 13, "R_Collar": 14, "Head": 15,
    "L_Shoulder": 16, "R_Shoulder": 17,
    "L_Elbow": 18, "R_Elbow": 19,
    "L_Wrist": 20, "R_Wrist": 21,
    "L_Index1": 22, "L_Index2": 23, "L_Index3": 24,
    "L_Middle1": 25, "L_Middle2": 26, "L_Middle3": 27,
    "L_Pinky1": 28, "L_Pinky2": 29, "L_Pinky3": 30,
    "L_Ring1": 31, "L_Ring2": 32, "L_Ring3": 33,
    "L_Thumb1": 34, "L_Thumb2": 35, "L_Thumb3": 36,
    "R_Index1": 37, "R_Index2": 38, "R_Index3": 39,
    "R_Middle1": 40, "R_Middle2": 41, "R_Middle3": 42,
    "R_Pinky1": 43, "R_Pinky2": 44, "R_Pinky3": 45,
    "R_Ring1": 46, "R_Ring2": 47, "R_Ring3": 48,
    "R_Thumb1": 49, "R_Thumb2": 50, "R_Thumb3": 51,
}

# Mapping from SMPL-H joint names to lowercase names used in some FBX templates
SMPLH_TO_LOWERCASE_MAPPING = {
    "Pelvis": "pelvis",
    "L_Hip": "left_hip",
    "R_Hip": "right_hip",
    "Spine1": "spine1",
    "L_Knee": "left_knee",
    "R_Knee": "right_knee",
    "Spine2": "spine2",
    "L_Ankle": "left_ankle",
    "R_Ankle": "right_ankle",
    "Spine3": "spine3",
    "L_Foot": "left_foot",
    "R_Foot": "right_foot",
    "Neck": "neck",
    "L_Collar": "left_collar",
    "R_Collar": "right_collar",
    "Head": "head",
    "L_Shoulder": "left_shoulder",
    "R_Shoulder": "right_shoulder",
    "L_Elbow": "left_elbow",
    "R_Elbow": "right_elbow",
    "L_Wrist": "left_wrist",
    "R_Wrist": "right_wrist",
    "L_Index1": "left_index1",
    "L_Index2": "left_index2",
    "L_Index3": "left_index3",
    "L_Middle1": "left_middle1",
    "L_Middle2": "left_middle2",
    "L_Middle3": "left_middle3",
    "L_Pinky1": "left_pinky1",
    "L_Pinky2": "left_pinky2",
    "L_Pinky3": "left_pinky3",
    "L_Ring1": "left_ring1",
    "L_Ring2": "left_ring2",
    "L_Ring3": "left_ring3",
    "L_Thumb1": "left_thumb1",
    "L_Thumb2": "left_thumb2",
    "L_Thumb3": "left_thumb3",
    "R_Index1": "right_index1",
    "R_Index2": "right_index2",
    "R_Index3": "right_index3",
    "R_Middle1": "right_middle1",
    "R_Middle2": "right_middle2",
    "R_Middle3": "right_middle3",
    "R_Pinky1": "right_pinky1",
    "R_Pinky2": "right_pinky2",
    "R_Pinky3": "right_pinky3",
    "R_Ring1": "right_ring1",
    "R_Ring2": "right_ring2",
    "R_Ring3": "right_ring3",
    "R_Thumb1": "right_thumb1",
    "R_Thumb2": "right_thumb2",
    "R_Thumb3": "right_thumb3",
}
# yapf: enable
93
+
94
+
95
def _loadFbxScene(fbxManager, filepath):
    """Load an FBX file into a new scene owned by ``fbxManager``.

    Raises a plain ``Exception`` carrying the SDK error string when the
    importer fails to initialize (bad path, unreadable file, ...).
    """
    importer = fbx.FbxImporter.Create(fbxManager, "")

    # -1 lets the SDK auto-detect the file format.
    if not importer.Initialize(filepath, -1, fbxManager.GetIOSettings()):
        raise Exception(
            f"Failed to initialize FBX importer for: {filepath}\nError: {importer.GetStatus().GetErrorString()}"
        )

    fbxScene = fbx.FbxScene.Create(fbxManager, "")
    importer.Import(fbxScene)
    # Importers hold native resources; destroy explicitly once done.
    importer.Destroy()

    return fbxScene
109
+
110
+
111
def _collectAllNodes(node, nodes_dict=None):
    """Collect every node reachable from ``node`` into a name -> FbxNode dict.

    Traversal is preorder (parent before children), so with duplicate names
    the deepest/last-visited node wins, matching the recursive original.
    """
    if nodes_dict is None:
        nodes_dict = {}

    pending = [node]
    while pending:
        current = pending.pop()
        nodes_dict[current.GetName()] = current
        # Push children in reverse so they are visited left-to-right.
        pending.extend(current.GetChild(i) for i in reversed(range(current.GetChildCount())))

    return nodes_dict
122
+
123
+
124
def _collectSkeletonNodes(node, skeleton_nodes=None):
    """Recursively collect nodes carrying a skeleton attribute (bones).

    Returns a dict of node_name -> FbxNode; non-bone nodes are traversed but
    not recorded.
    """
    if skeleton_nodes is None:
        skeleton_nodes = {}

    # Check if this node has a skeleton attribute
    attr = node.GetNodeAttribute()
    if attr and attr.GetAttributeType() == fbx.FbxNodeAttribute.EType.eSkeleton:
        skeleton_nodes[node.GetName()] = node

    for i in range(node.GetChildCount()):
        _collectSkeletonNodes(node.GetChild(i), skeleton_nodes)

    return skeleton_nodes
138
+
139
+
140
def _animateSingleChannel(animLayer, component, name, values, frameDuration):
    """Animate a single channel (X, Y, or Z) with keyframes.

    Args:
        animLayer: Target FbxAnimLayer.
        component: Animatable property (e.g. ``node.LclRotation``).
        name: Channel name, "X"/"Y"/"Z"; also selects which column of
            ``values`` to key (unknown names fall back to column 0).
        values: Per-frame sequence of 3-component vectors.
        frameDuration: Seconds per frame.
    """
    ncomp = {"X": 0, "Y": 1, "Z": 2}.get(name, 0)

    time = fbx.FbxTime()
    # True -> create the curve if it does not exist yet.
    curve = component.GetCurve(animLayer, name, True)
    curve.KeyModifyBegin()
    for nth in range(len(values)):
        time.SetSecondDouble(nth * frameDuration)
        keyIndex = curve.KeyAdd(time)[0]
        curve.KeySetValue(keyIndex, values[nth][ncomp])
        # Constant interpolation: hold each key's value until the next frame.
        curve.KeySetInterpolation(keyIndex, fbx.FbxAnimCurveDef.EInterpolationType.eInterpolationConstant)
    curve.KeyModifyEnd()
153
+
154
+
155
def _animateRotationKeyFrames(animLayer, node, rot_matrices, frameDuration):
    """Animate rotation keyframes for a node using rotation matrices.

    Each (3, 3) matrix is converted to static-XYZ Euler angles in degrees
    before keying the node's local-rotation channels.
    """
    rotations = []
    for nth in range(len(rot_matrices)):
        # Convert rotation matrix to Euler angles (XYZ order)
        euler = np.rad2deg(mat2euler(rot_matrices[nth], axes="sxyz"))
        rotations.append(euler)

    _animateSingleChannel(animLayer, node.LclRotation, "X", rotations, frameDuration)
    _animateSingleChannel(animLayer, node.LclRotation, "Y", rotations, frameDuration)
    _animateSingleChannel(animLayer, node.LclRotation, "Z", rotations, frameDuration)
166
+
167
+
168
def _animateTranslationKeyFrames(animLayer, node, translations, frameDuration):
    """Animate translation keyframes for a node.

    ``translations`` may be a torch tensor or ndarray; it is normalized to a
    float64 (num_frames, 3) array before keying the X/Y/Z channels.
    """
    # Ensure translations is a numpy array with shape (num_frames, 3)
    if isinstance(translations, torch.Tensor):
        translations = translations.numpy()
    translations = np.asarray(translations, dtype=np.float64)

    if len(translations.shape) == 1:
        # Single frame, reshape to (1, 3)
        translations = translations.reshape(1, -1)

    _animateSingleChannel(animLayer, node.LclTranslation, "X", translations, frameDuration)
    _animateSingleChannel(animLayer, node.LclTranslation, "Y", translations, frameDuration)
    _animateSingleChannel(animLayer, node.LclTranslation, "Z", translations, frameDuration)
182
+
183
+
184
def _clearExistingAnimations(fbxScene):
    """Remove all existing animation stacks from the scene.

    Iterates in reverse so destroying a stack does not shift the indices of
    stacks not yet visited.
    """
    anim_stack_count = fbxScene.GetSrcObjectCount(fbx.FbxCriteria.ObjectType(fbx.FbxAnimStack.ClassId))
    for i in range(anim_stack_count - 1, -1, -1):
        anim_stack = fbxScene.GetSrcObject(fbx.FbxCriteria.ObjectType(fbx.FbxAnimStack.ClassId), i)
        if anim_stack:
            anim_stack.Destroy()
191
+
192
+
193
def _applyAnimationToSkeleton(fbxScene, nodes_map, rot_matrices, translations, fps, smplh_to_fbx_mapping, name="Take1"):
    """
    Apply SMPL-H animation data to skeleton nodes in the FBX scene.

    Args:
        fbxScene: FBX scene object
        nodes_map: Dictionary of node_name -> FbxNode
        rot_matrices: (num_frames, num_joints, 3, 3) rotation matrices
        translations: (num_frames, 3) root translations (relative displacement, not absolute position)
        fps: Frame rate
        smplh_to_fbx_mapping: Mapping from SMPL-H joint names to FBX node names
        name: Animation take name

    Returns:
        The created FbxAnimStack.
    """
    frameDuration = 1.0 / fps
    num_frames = rot_matrices.shape[0]
    num_joints = rot_matrices.shape[1]

    # Create animation stack and layer
    animStack = fbx.FbxAnimStack.Create(fbxScene, name)
    animLayer = fbx.FbxAnimLayer.Create(fbxScene, "Base Layer")
    animStack.AddMember(animLayer)

    # Track if root translation was applied
    root_translation_applied = False
    root_node = None

    # Get root node's initial LclTranslation from template (this is like Translates[0] in smplh2woodfbx.py)
    root_initial_translation = None
    root_fbx_name = smplh_to_fbx_mapping.get("Pelvis")
    if root_fbx_name and root_fbx_name in nodes_map:
        root_node_temp = nodes_map[root_fbx_name]
        initial_trans = root_node_temp.LclTranslation.Get()
        root_initial_translation = np.array([initial_trans[0], initial_trans[1], initial_trans[2]])
        print(f"Root initial LclTranslation from template: {root_initial_translation}")

    # Animate each joint
    for smplh_joint_name, smplh_joint_idx in SMPLH_JOINT2NUM.items():
        # Skip joints the pose data does not cover (e.g. body-only input).
        if smplh_joint_idx >= num_joints:
            continue

        # Get the FBX node name from mapping
        fbx_node_name = smplh_to_fbx_mapping.get(smplh_joint_name)
        if not fbx_node_name:
            if smplh_joint_idx == 0:
                print(f"Warning: Root joint 'Pelvis' not found in mapping!")
            continue

        # Find the node
        node = nodes_map.get(fbx_node_name)
        if not node:
            print(f"Warning: Joint '{smplh_joint_name}' (FBX: '{fbx_node_name}') not found in scene")
            continue

        # Animate rotation
        _animateRotationKeyFrames(
            animLayer=animLayer,
            node=node,
            rot_matrices=rot_matrices[:, smplh_joint_idx],
            frameDuration=frameDuration,
        )

        # Animate translation for root joint (Pelvis)
        if smplh_joint_idx == 0:
            root_node = node
            # Add initial offset to translations (like smplh2woodfbx.py does: Translates[0] + trans)
            # The translations input is relative displacement, we need to add the template's initial position
            if root_initial_translation is not None:
                final_translations = translations + root_initial_translation
                print(
                    f"Applying root translation to '{fbx_node_name}', frames={num_frames}, "
                    f"initial_offset={root_initial_translation}, "
                    f"final translation range: {final_translations.min(axis=0)} to {final_translations.max(axis=0)}"
                )
            else:
                final_translations = translations
                print(
                    f"Applying root translation to '{fbx_node_name}', frames={num_frames}, "
                    f"translation range: {final_translations.min(axis=0)} to {final_translations.max(axis=0)}"
                )
            _animateTranslationKeyFrames(
                animLayer=animLayer,
                node=node,
                translations=final_translations,
                frameDuration=frameDuration,
            )
            root_translation_applied = True

    # If root translation was not applied, try to find root node by common names
    if not root_translation_applied:
        print("Warning: Root translation was not applied through normal mapping, trying fallback...")
        root_candidates = ["Pelvis", "pelvis", "Hips", "hips", "Root", "root", "mixamorig:Hips"]
        for candidate in root_candidates:
            if candidate in nodes_map:
                root_node = nodes_map[candidate]
                # Get initial translation for fallback node
                initial_trans = root_node.LclTranslation.Get()
                fallback_initial = np.array([initial_trans[0], initial_trans[1], initial_trans[2]])
                final_translations = translations + fallback_initial
                print(
                    f"Found root node by fallback: '{candidate}', initial_offset={fallback_initial}, applying translation..."
                )
                _animateTranslationKeyFrames(
                    animLayer=animLayer,
                    node=root_node,
                    translations=final_translations,
                    frameDuration=frameDuration,
                )
                root_translation_applied = True
                break

    if not root_translation_applied:
        print("ERROR: Could not find root node to apply translation!")
        print(f"Available nodes: {list(nodes_map.keys())}")

    return animStack
308
+
309
+
310
def _saveScene(filename, fbxManager, fbxScene, embed_textures=True):
    """Save the FBX scene to a file

    Args:
        filename: Output file path
        fbxManager: FBX manager instance
        fbxScene: FBX scene to save
        embed_textures: Whether to embed textures/media in the FBX file (default True)

    Raises:
        Exception: If the exporter fails to initialize for ``filename``.
    """
    # Configure IOSettings to embed textures/media
    ios = fbxManager.GetIOSettings()
    if embed_textures:
        ios.SetBoolProp(fbx.EXP_FBX_EMBEDDED, True)
        ios.SetBoolProp(fbx.EXP_FBX_MATERIAL, True)
        ios.SetBoolProp(fbx.EXP_FBX_TEXTURE, True)

    exporter = fbx.FbxExporter.Create(fbxManager, "")
    # -1 lets the SDK pick the writer from the file extension.
    isInitialized = exporter.Initialize(filename, -1, ios)

    if isInitialized is False:
        raise Exception(f"Exporter failed to initialize. Error: {exporter.GetStatus().GetErrorString()}")

    exporter.Export(fbxScene)
    exporter.Destroy()
334
+
335
+
336
def _convert_smplh_to_woodfbx(
    template_fbx_path,
    npz_data,
    save_fn,
    fps=30,
    scale=100,
    smplh_to_fbx_mapping=None,
    clear_animations=True,
):
    """
    Convert SMPL-H parameters to FBX using a template FBX file.
    The template FBX skeleton is already consistent with SMPL-H, so we directly copy parameters.

    Args:
        template_fbx_path: Path to the template FBX file (e.g., boy_Rigging_smplx.fbx)
        npz_data: Dictionary containing SMPL-H parameters
            - poses: (num_frames, 52, 3) or (num_frames, 156)
            - trans: (num_frames, 3)
        save_fn: Output FBX file path
        fps: Frame rate
        scale: Scale factor for translation (default 100 for m to cm conversion)
        smplh_to_fbx_mapping: Custom mapping from SMPL-H joint names to FBX node names
        clear_animations: Whether to clear existing animations in the template

    Returns:
        bool: True if successful
    """
    # Prepare poses data
    poses = npz_data["poses"]
    if isinstance(poses, np.ndarray):
        poses = torch.from_numpy(poses).float()

    if len(poses.shape) == 2:
        # (num_frames, 156) -> (num_frames, 52, 3)
        poses = poses.reshape(poses.shape[0], -1, 3)

    # Convert axis-angle to rotation matrices: (num_frames, num_joints, 3, 3)
    rot_matrices = angle_axis_to_rotation_matrix(poses).numpy()

    # Prepare translation data
    trans = npz_data["trans"]
    if isinstance(trans, torch.Tensor):
        trans = trans.numpy()

    # Apply scale to translation
    translations = trans * scale

    # Create FBX manager and load template
    fbxManager = fbx.FbxManager.Create()
    ios = fbx.FbxIOSettings.Create(fbxManager, fbx.IOSROOT)
    fbxManager.SetIOSettings(ios)

    print(f"Loading FBX template: {template_fbx_path}")
    fbxScene = _loadFbxScene(fbxManager, template_fbx_path)

    # Set time mode
    timeMode = fbx.FbxTime().ConvertFrameRateToTimeMode(fps)
    fbxScene.GetGlobalSettings().SetTimeMode(timeMode)

    # Collect all nodes
    rootNode = fbxScene.GetRootNode()
    all_nodes = _collectAllNodes(rootNode)
    skeleton_nodes = _collectSkeletonNodes(rootNode)

    print(f"Found {len(all_nodes)} nodes in scene")
    print(f"Found {len(skeleton_nodes)} skeleton nodes: {list(skeleton_nodes.keys())}")

    # Use default mapping if not provided
    if smplh_to_fbx_mapping is None:
        smplh_to_fbx_mapping = _auto_detect_mapping(all_nodes)
        print(f"Auto-detected {len(smplh_to_fbx_mapping)} joint mappings")
        if "Pelvis" in smplh_to_fbx_mapping:
            print(f"  Root joint 'Pelvis' mapped to: '{smplh_to_fbx_mapping['Pelvis']}'")
        else:
            print(f"  WARNING: Root joint 'Pelvis' not found in mapping!")
            print(f"  Available nodes: {list(all_nodes.keys())[:20]}...")  # Show first 20 nodes

    # Clear existing animations if requested
    if clear_animations:
        _clearExistingAnimations(fbxScene)

    # Apply animation to skeleton
    _applyAnimationToSkeleton(
        fbxScene=fbxScene,
        nodes_map=all_nodes,
        rot_matrices=rot_matrices,
        translations=translations,
        fps=fps,
        smplh_to_fbx_mapping=smplh_to_fbx_mapping,
        name="SMPLH_Animation",
    )

    # Save to temporary file first, then copy to final destination
    os.makedirs(os.path.dirname(save_fn) if os.path.dirname(save_fn) else ".", exist_ok=True)
    with tempfile.NamedTemporaryFile(suffix=".fbx", delete=False) as tmp_f:
        temp_file = tmp_f.name

    try:
        _saveScene(temp_file, fbxManager, fbxScene)
        shutil.copy2(temp_file, save_fn)
        print(f"Successfully saved FBX to: {save_fn}")
    except Exception as e:
        print(f"Error saving FBX file: {e}")
        return False
    finally:
        # Bugfix: previously the temp file was removed only on success, so a
        # failed export/copy leaked one temp .fbx per call. Clean up and
        # release the SDK manager in all cases.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        fbxManager.Destroy()

    return os.path.exists(save_fn)
446
+
447
+
448
def _auto_detect_mapping(all_nodes):
    """Map SMPL-H joint names onto FBX node names, preferring an exact match.

    Falls back to the known lowercase alias when the exact name is absent;
    joints with neither match are simply omitted from the result.
    """
    mapping = {}
    for joint_name in SMPLH_JOINT2NUM:
        if joint_name in all_nodes:
            mapping[joint_name] = joint_name
            continue
        alias = SMPLH_TO_LOWERCASE_MAPPING.get(joint_name)
        if alias in all_nodes:
            mapping[joint_name] = alias
    return mapping
459
+
460
+
461
class SMPLH2WoodFBX:
    """
    Class to convert SMPL-H parameters to FBX using a template FBX file.
    The template FBX skeleton is already consistent with SMPL-H, so we directly copy parameters.
    No SMPL-H model assets (model.npz) required.

    Example usage:
        converter = SMPLH2WoodFBX(
            template_fbx_path="./assets/wooden_models/boy_Rigging_smplx.fbx"
        )

        # From npz file
        converter.convert_npz_to_fbx("motion.npz", "output.fbx", fps=30)

        # From parameters dict
        params = {
            "poses": poses_array,  # (num_frames, 52, 3) or (num_frames, 156)
            "trans": trans_array,  # (num_frames, 3)
        }
        converter.convert_params_to_fbx(params, "output.fbx")
    """

    def __init__(
        self,
        template_fbx_path: str = "./assets/wooden_models/boy_Rigging_smplx_tex.fbx",
        smplh_to_fbx_mapping: Optional[Dict[str, str]] = None,
        scale: float = 100,
    ):
        """
        Initialize the converter.

        Args:
            template_fbx_path: Path to the template FBX file
            smplh_to_fbx_mapping: Custom mapping from SMPL-H joint names to FBX node names
            scale: Scale factor for translation (default 100 for m to cm conversion)
        """
        print(f"[{self.__class__.__name__}] Template FBX: {template_fbx_path}")
        self.template_fbx_path = template_fbx_path
        self.smplh_to_fbx_mapping = smplh_to_fbx_mapping
        self.scale = scale

        # Analyze template FBX to detect joint names
        self._analyze_template()

    def _analyze_template(self):
        """Analyze the template FBX file to detect available skeleton nodes.

        Populates ``self.all_template_nodes`` / ``self.skeleton_template_nodes``
        and auto-detects ``self.smplh_to_fbx_mapping`` when none was provided.
        """
        fbxManager = fbx.FbxManager.Create()
        ios = fbx.FbxIOSettings.Create(fbxManager, fbx.IOSROOT)
        fbxManager.SetIOSettings(ios)

        try:
            fbxScene = _loadFbxScene(fbxManager, self.template_fbx_path)
            rootNode = fbxScene.GetRootNode()

            self.all_template_nodes = list(_collectAllNodes(rootNode).keys())
            self.skeleton_template_nodes = list(_collectSkeletonNodes(rootNode).keys())

            print(f"[{self.__class__.__name__}] Template nodes: {len(self.all_template_nodes)}")
            print(f"[{self.__class__.__name__}] Skeleton nodes: {self.skeleton_template_nodes}")

            # Auto-detect mapping if not provided
            if self.smplh_to_fbx_mapping is None:
                self.smplh_to_fbx_mapping = self._auto_detect_mapping()
                print(f"[{self.__class__.__name__}] Auto-detected {len(self.smplh_to_fbx_mapping)} joint mappings")
        finally:
            # Always release the SDK manager, even when loading fails.
            fbxManager.Destroy()

    def _auto_detect_mapping(self):
        """Auto-detect the mapping from SMPL-H joints to FBX nodes"""
        mapping = {}
        for smplh_name in SMPLH_JOINT2NUM.keys():
            # Try exact match
            if smplh_name in self.all_template_nodes:
                mapping[smplh_name] = smplh_name
            # Try lowercase version
            elif SMPLH_TO_LOWERCASE_MAPPING.get(smplh_name) in self.all_template_nodes:
                mapping[smplh_name] = SMPLH_TO_LOWERCASE_MAPPING[smplh_name]
        return mapping

    def convert_npz_to_fbx(self, npz_file, outname, fps=30, clear_animations=True):
        """
        Convert an npz file containing SMPL-H parameters to FBX.

        Args:
            npz_file: Path to the npz file or dict containing SMPL-H parameters
            outname: Output FBX file path
            fps: Frame rate
            clear_animations: Whether to clear existing animations in template

        Returns:
            bool: True if successful
        """
        os.makedirs(os.path.dirname(outname) if os.path.dirname(outname) else ".", exist_ok=True)

        if isinstance(npz_file, str) and os.path.isfile(npz_file):
            npz_data = dict(np.load(npz_file, allow_pickle=True))
        else:
            # Anything else is assumed to already be a params dict.
            npz_data = npz_file

        return _convert_smplh_to_woodfbx(
            template_fbx_path=self.template_fbx_path,
            npz_data=npz_data,
            save_fn=outname,
            fps=fps,
            scale=self.scale,
            smplh_to_fbx_mapping=self.smplh_to_fbx_mapping,
            clear_animations=clear_animations,
        )

    def convert_params_to_fbx(self, params, outname, clear_animations=True):
        """
        Convert SMPL-H parameters to FBX.

        Args:
            params: Dictionary containing SMPL-H parameters
                - poses: (num_frames, 52, 3) or (num_frames, 156)
                - trans: (num_frames, 3)
                - mocap_framerate (optional): Frame rate
            outname: Output FBX file path
            clear_animations: Whether to clear existing animations in template

        Returns:
            bool: True if successful
        """
        fps = params.get("mocap_framerate", 30)
        os.makedirs(os.path.dirname(outname) if os.path.dirname(outname) else ".", exist_ok=True)

        npz_data = {
            "poses": params["poses"],
            "trans": params["trans"],
        }

        return _convert_smplh_to_woodfbx(
            template_fbx_path=self.template_fbx_path,
            npz_data=npz_data,
            save_fn=outname,
            fps=fps,
            scale=self.scale,
            smplh_to_fbx_mapping=self.smplh_to_fbx_mapping,
            clear_animations=clear_animations,
        )
602
+
603
+
604
if __name__ == "__main__":
    # CLI entry point: convert one .npz file, or every .npz under a directory,
    # to FBX using the wooden-boy template.
    # python hymotion/utils/smplh2woodfbx.py
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("root", type=str)  # .npz file or directory of .npz files
    args = parser.parse_args()

    converter = SMPLH2WoodFBX(
        template_fbx_path="./assets/wooden_models/boy_Rigging_smplx_tex.fbx",
        scale=100,  # meters -> centimeters
    )

    if os.path.isdir(args.root):
        npzfiles = sorted(glob.glob(os.path.join(args.root, "*.npz")))
    else:
        if args.root.endswith(".npz"):
            npzfiles = [args.root]
        else:
            raise ValueError(f"Unknown file type: {args.root}")

    for npzfile in npzfiles:
        # Mirror ".../motions/x.npz" to ".../motions_fbx/x.fbx".
        converter.convert_npz_to_fbx(npzfile, npzfile.replace(".npz", ".fbx").replace("motions", "motions_fbx"))
hymotion/utils/t2m_runtime.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # t2m_runtime.py
2
+ import os
3
+ import threading
4
+ import time
5
+ import uuid
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import yaml
10
+
11
+ from ..prompt_engineering.prompt_rewrite import PromptRewriter
12
+ from .loaders import load_object
13
+ from .visualize_mesh_web import save_visualization_data, generate_static_html_content
14
+
15
# Optional dependency: the FBX SDK is only needed for FBX export.  The rest of
# the runtime keeps working without it; export falls back to dict output.
try:
    import fbx

    FBX_AVAILABLE = True
    print(">>> FBX module found.")
except ImportError:
    FBX_AVAILABLE = False
    print(">>> FBX module not found.")
23
+
24
+
25
+ def _get_local_ip():
26
+ import subprocess
27
+
28
+ result = subprocess.run(["hostname", "-I"], capture_output=True, text=True, timeout=5)
29
+ if result.returncode == 0:
30
+ for ip in result.stdout.strip().split():
31
+ if not ip.startswith("127.") and not ip.startswith("172.17."):
32
+ return ip
33
+ return "localhost"
34
+
35
+
36
+ def _now():
37
+ t = time.time()
38
+ ms = int((t - int(t)) * 1000)
39
+ return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}"
40
+
41
+
42
class T2MRuntime:
    """Text-to-motion (T2M) inference runtime.

    Builds one diffusion pipeline per selected device (or a single CPU
    pipeline), hands pipelines out to concurrent callers one request at a
    time, and post-processes generated motion into NPZ / HTML / FBX outputs.
    """

    def __init__(
        self,
        config_path: str,
        ckpt_name: str = "latest.ckpt",
        skip_text: bool = False,
        device_ids: Union[list[int], None] = None,
        prompt_engineering_host: Optional[str] = None,
        skip_model_loading: bool = False,
        force_cpu: bool = False,
    ):
        """
        Args:
            config_path: YAML config describing pipeline/network modules.
            ckpt_name: Checkpoint path passed to ``load_in_demo``.
            skip_text: If True, the text encoder is not built (debug mode).
            device_ids: CUDA device ids to use; None means all visible GPUs.
            prompt_engineering_host: Optional host for the prompt rewriter.
            skip_model_loading: If True, tolerate a missing checkpoint and
                run with randomly initialized weights.
            force_cpu: Force CPU execution even when CUDA is available.
        """
        self.config_path = config_path
        self.ckpt_name = ckpt_name
        self.skip_text = skip_text
        self.prompt_engineering_host = prompt_engineering_host
        self.skip_model_loading = skip_model_loading
        self.local_ip = _get_local_ip()

        if force_cpu:
            print(">>> [INFO] CPU mode enabled via HY_MOTION_DEVICE=cpu environment variable")
            self.device_ids = []
        elif torch.cuda.is_available():
            # Keep only requested ids that actually exist on this host.
            all_ids = list(range(torch.cuda.device_count()))
            self.device_ids = all_ids if device_ids is None else [i for i in device_ids if i in all_ids]
        else:
            self.device_ids = []

        self.pipelines = []
        self._gpu_load = []  # one slot per pipeline: 0 = free, 1 = busy
        self._lock = threading.Lock()  # guards _gpu_load
        self._loaded = False

        self.prompt_rewriter = PromptRewriter(host=self.prompt_engineering_host)
        # Skip model loading if checkpoint not found
        if self.skip_model_loading:
            print(">>> [WARNING] Checkpoint not found, will use randomly initialized model weights")
        self.load()
        self.fbx_available = FBX_AVAILABLE
        if self.fbx_available:
            try:
                from .smplh2woodfbx import SMPLH2WoodFBX

                self.fbx_converter = SMPLH2WoodFBX()
            except Exception as e:
                # FBX SDK import succeeded but converter setup failed
                # (e.g. missing template asset) -> disable FBX export.
                print(f">>> Failed to initialize FBX converter: {e}")
                self.fbx_available = False
                self.fbx_converter = None
        else:
            self.fbx_converter = None
            print(">>> FBX module not found. FBX export will be disabled.")

        device_info = self.device_ids if self.device_ids else "cpu"
        if self.skip_model_loading:
            print(f">>> T2MRuntime initialized (using randomly initialized weights) in IP {self.local_ip}, devices={device_info}")
        else:
            print(f">>> T2MRuntime loaded in IP {self.local_ip}, devices={device_info}")

    def load(self):
        """Build and load one pipeline per device (idempotent)."""
        if self._loaded:
            return
        print(f">>> Loading model from {self.config_path}...")

        with open(self.config_path, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        # Use allow_empty_ckpt=True when skip_model_loading is True
        allow_empty_ckpt = self.skip_model_loading

        if not self.device_ids:
            # CPU path: a single pipeline.
            pipeline = load_object(
                config["train_pipeline"],
                config["train_pipeline_args"],
                network_module=config["network_module"],
                network_module_args=config["network_module_args"],
            )
            device = torch.device("cpu")
            pipeline.load_in_demo(
                self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text, allow_empty_ckpt=allow_empty_ckpt
            )
            pipeline.to(device)
            self.pipelines = [pipeline]
            self._gpu_load = [0]
        else:
            # One independent pipeline per selected GPU.
            for gid in self.device_ids:
                p = load_object(
                    config["train_pipeline"],
                    config["train_pipeline_args"],
                    network_module=config["network_module"],
                    network_module_args=config["network_module_args"],
                )
                p.load_in_demo(self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text, allow_empty_ckpt=allow_empty_ckpt)
                p.to(torch.device(f"cuda:{gid}"))
                self.pipelines.append(p)
            self._gpu_load = [0] * len(self.pipelines)

        self._loaded = True

    def _acquire_pipeline(self) -> int:
        """Block until a pipeline is free, mark it busy, and return its index.

        Busy-waits with a 10 ms poll; pair every call with _release_pipeline.
        """
        while True:
            with self._lock:
                for i in range(len(self._gpu_load)):
                    if self._gpu_load[i] == 0:
                        self._gpu_load[i] = 1
                        return i
            time.sleep(0.01)

    def _release_pipeline(self, idx: int):
        """Mark pipeline `idx` as free again."""
        with self._lock:
            self._gpu_load[idx] = 0

    def test_dit_inference(self, duration: float = 2.0, seed: int = 42) -> bool:
        """
        Test DiT model inference with unconditional/blank input.
        This method is used to verify the DiT model works before loading text encoder.

        Args:
            duration: Duration of the test motion in seconds
            seed: Random seed for reproducibility

        Returns:
            True if inference succeeds and produces valid output

        Raises:
            RuntimeError: If no pipeline has been loaded yet.
        """
        if not self.pipelines:
            raise RuntimeError("No pipeline loaded. Call load() first.")

        pi = self._acquire_pipeline()
        try:
            pipeline = self.pipelines[pi]
            pipeline.eval()
            device = next(pipeline.parameters()).device

            # Calculate frame length from duration (assuming 30fps output, 20fps internal)
            length = int(duration * 20)
            length = min(length, pipeline.train_frames)

            # Use null features for unconditional generation
            batch_size = 1
            vtxt_input = pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device)
            ctxt_input = pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device)
            ctxt_length = torch.tensor([1] * batch_size, device=device)

            # Create masks
            from ..pipeline.motion_diffusion import length_to_mask

            ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
            x_length = torch.LongTensor([length] * batch_size).to(device)
            x_mask_temporal = length_to_mask(x_length, pipeline.train_frames)

            # Run denoising inference
            print(f"\t>>> Running DiT inference test: length={length}, device={device}")

            # Create random noise
            generator = torch.Generator(device=device).manual_seed(seed)
            latent_shape = (batch_size, pipeline.train_frames, pipeline.mean.shape[-1])
            latents = torch.randn(latent_shape, generator=generator, device=device, dtype=vtxt_input.dtype)

            # Simple single-step denoising test (just forward pass)
            with torch.no_grad():
                # Get timestep
                timesteps = torch.tensor([0.5], device=device, dtype=vtxt_input.dtype).expand(batch_size)

                # Forward pass through DiT
                # Use correct parameter names for HunyuanMotionMMDiT.forward()
                _ = pipeline.motion_transformer(
                    x=latents,
                    ctxt_input=ctxt_input,
                    vtxt_input=vtxt_input,
                    timesteps=timesteps,
                    x_mask_temporal=x_mask_temporal,
                    ctxt_mask_temporal=ctxt_mask_temporal,
                )

            print(f"\t>>> DiT forward pass completed successfully!")
            return True

        except Exception as e:
            print(f"\t>>> DiT inference test failed: {e}")
            raise
        finally:
            self._release_pipeline(pi)

    def load_text_encoder(self) -> None:
        """
        Load text encoder for all pipelines.
        This is called after DiT model testing to complete the initialization.

        Raises:
            RuntimeError: If no pipeline has been loaded yet.
        """
        if not self.pipelines:
            raise RuntimeError("No pipeline loaded. Call load() first.")

        print(">>> Loading text encoder for all pipelines...")
        for i, pipeline in enumerate(self.pipelines):
            if not hasattr(pipeline, "text_encoder") or pipeline.text_encoder is None:
                device = next(pipeline.parameters()).device
                pipeline.text_encoder = load_object(pipeline._text_encoder_module, pipeline._text_encoder_cfg)
                pipeline.text_encoder.to(device)
                print(f"\t>>> Text encoder loaded for pipeline {i} on {device}")

        # Update skip_text flag
        self.skip_text = False
        print(">>> Text encoder loading completed!")

    def rewrite_text_and_infer_time(self, text: str) -> Tuple[float, str]:
        """Rewrite a prompt and infer the motion duration via the rewriter service.

        Returns:
            (duration_seconds, rewritten_text)
        """
        print("Start rewriting text...")
        duration, rewritten_text = self.prompt_rewriter.rewrite_prompt_and_infer_time(f"{text}")
        print(f"\t>>> Rewritten text: {rewritten_text}, duration: {duration:.2f} seconds")
        return duration, rewritten_text

    def generate_motion(
        self,
        text: str,
        seeds_csv: str,
        duration: float,
        cfg_scale: float,
        output_format: str = "fbx",
        output_dir: Optional[str] = None,
        output_filename: Optional[str] = None,
        original_text: Optional[str] = None,
        use_special_game_feat: bool = False,
    ) -> Tuple[str, List[str], dict]:
        """Generate motion from text and persist visualization artifacts.

        Args:
            text: Prompt fed to the pipeline (typically already rewritten).
            seeds_csv: Comma-separated seeds; one sample is generated per seed.
            duration: Target motion duration in seconds.
            cfg_scale: Classifier-free-guidance scale (ignored in skip_text mode).
            output_format: "fbx" (falls back to "dict" if the SDK is missing)
                or "dict".
            output_dir: Directory for NPZ/meta/FBX files (default: output/gradio).
            output_filename: Optional basename for all written files.
            original_text: Original prompt stored in metadata (default: `text`).
            use_special_game_feat: Pipeline-specific generation flag.

        Returns:
            (html_content, fbx_file_paths, raw_model_output); fbx_file_paths
            is empty for the "dict" format.

        Raises:
            ValueError: For an unknown `output_format`.
        """
        self.load()
        seeds = [int(s.strip()) for s in seeds_csv.split(",") if s.strip() != ""]
        pi = self._acquire_pipeline()
        try:
            pipeline = self.pipelines[pi]
            pipeline.eval()

            # When skip_text=True (debug mode), use blank text features
            if self.skip_text:
                print(">>> [Debug Mode] Using blank text features (skip_text=True)")
                device = next(pipeline.parameters()).device
                batch_size = len(seeds) if seeds else 1
                # Create blank hidden_state_dict using null features
                hidden_state_dict = {
                    "text_vec_raw": pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device),
                    "text_ctxt_raw": pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device),
                    "text_ctxt_raw_length": torch.tensor([1] * batch_size, device=device),
                }
                # Disable CFG in debug mode (use cfg_scale=1.0)
                model_output = pipeline.generate(
                    text,
                    seeds,
                    duration,
                    cfg_scale=1.0,
                    use_special_game_feat=False,
                    hidden_state_dict=hidden_state_dict,
                )
            else:
                model_output = pipeline.generate(
                    text, seeds, duration, cfg_scale=cfg_scale, use_special_game_feat=use_special_game_feat
                )
        finally:
            self._release_pipeline(pi)

        ts = _now()
        save_data, base_filename = save_visualization_data(
            output=model_output,
            text=text if original_text is None else original_text,
            rewritten_text=text,
            timestamp=ts,
            output_dir=output_dir,
            output_filename=output_filename,
        )

        html_content = self._generate_html_content(
            timestamp=ts,
            file_path=base_filename,
            output_dir=output_dir,
        )

        if output_format == "fbx" and not self.fbx_available:
            print(">>> Warning: FBX export requested but FBX SDK is not available. Falling back to dict format.")
            output_format = "dict"

        if output_format == "fbx" and self.fbx_available:
            fbx_files = self._generate_fbx_files(
                visualization_data=save_data,
                output_dir=output_dir,
                fbx_filename=output_filename,
            )
            return html_content, fbx_files, model_output
        elif output_format == "dict":
            # Return HTML content and empty list for fbx_files when using dict format
            return html_content, [], model_output
        else:
            raise ValueError(f">>> Invalid output format: {output_format}")

    def _generate_html_content(
        self,
        timestamp: str,
        file_path: str,
        output_dir: Optional[str] = None,
    ) -> str:
        """
        Generate static HTML content with embedded data for iframe srcdoc.
        All JavaScript code is embedded directly in the HTML, no external static resources needed.

        Args:
            timestamp: Timestamp string for logging
            file_path: Base filename (without extension)
            output_dir: Directory where NPZ/meta files are stored

        Returns:
            HTML content string (to be used in iframe srcdoc); on failure a
            minimal error page is returned instead of raising.
        """
        print(f">>> Generating static HTML content, timestamp: {timestamp}")
        gradio_dir = output_dir if output_dir is not None else "output/gradio"

        try:
            # Generate static HTML content with embedded data (all JS is embedded in template)
            html_content = generate_static_html_content(
                folder_name=gradio_dir,
                file_name=file_path,
                hide_captions=False,
            )

            print(f">>> Static HTML content generated for: {file_path}")
            return html_content

        except Exception as e:
            print(f">>> Failed to generate static HTML content: {e}")
            import traceback
            traceback.print_exc()
            # Return error HTML
            return f"<html><body><h1>Error generating visualization</h1><p>{str(e)}</p></body></html>"

    def _generate_fbx_files(
        self,
        visualization_data: dict,
        output_dir: Optional[str] = None,
        fbx_filename: Optional[str] = None,
    ) -> List[str]:
        """Export one FBX (plus a sidecar .txt prompt file) per generated sample.

        Args:
            visualization_data: Dict from save_visualization_data; must contain
                "smpl_data", "text" and "timestamp".
            output_dir: Target directory (default: <repo>/output/gradio).
            fbx_filename: Optional basename; otherwise timestamp + uuid is used.

        Returns:
            Paths of all files written (.fbx and .txt, successful conversions only).
        """
        assert "smpl_data" in visualization_data, "smpl_data not found in visualization_data"
        fbx_files = []
        if output_dir is None:
            root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
            output_dir = os.path.join(root_dir, "output", "gradio")

        smpl_data_list = visualization_data["smpl_data"]

        # Short random id avoids collisions when the same timestamp recurs.
        unique_id = str(uuid.uuid4())[:8]
        text = visualization_data["text"]
        timestamp = visualization_data["timestamp"]
        for bb in range(len(smpl_data_list)):
            smpl_data = smpl_data_list[bb]
            if fbx_filename is None:
                fbx_filename_bb = f"{timestamp}_{unique_id}_{bb:03d}.fbx"
            else:
                fbx_filename_bb = f"{fbx_filename}_{bb:03d}.fbx"
            fbx_path = os.path.join(output_dir, fbx_filename_bb)
            success = self.fbx_converter.convert_npz_to_fbx(smpl_data, fbx_path)
            if success:
                fbx_files.append(fbx_path)
                print(f"\t>>> FBX file generated: {fbx_path}")
                # Store the prompt next to the FBX for traceability.
                txt_path = fbx_path.replace(".fbx", ".txt")
                with open(txt_path, "w", encoding="utf-8") as f:
                    f.write(text)
                fbx_files.append(txt_path)

        return fbx_files
hymotion/utils/type_converter.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
def get_module_device(module: nn.Module) -> torch.device:
    """Get the device of a module.

    Assumes all parameters live on a single device and inspects the first one.

    Args:
        module (nn.Module): A module that contains parameters.

    Returns:
        torch.device: The device of the module (CUDA index for GPU modules,
        otherwise CPU).

    Raises:
        ValueError: If the module has no parameters.
    """
    try:
        param = next(module.parameters())
    except StopIteration:
        # `from None` hides the internal StopIteration from the traceback.
        raise ValueError("The input module should contain parameters.") from None

    # Preserve original behavior: CUDA params map to their device index,
    # everything else (cpu/meta/...) maps to CPU.
    if param.is_cuda:
        return torch.device(param.get_device())

    return torch.device("cpu")
hymotion/utils/visualize_mesh_web.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ import threading
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import Tensor
10
+
11
# Serializes file reads/writes (NPZ, meta JSON, HTML) across worker threads.
_FILE_ACCESS_LOCK = threading.Lock()

# Template directory path: <repo root>/scripts/gradio/templates
_TEMPLATE_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
    "scripts", "gradio", "templates"
)
18
+
19
+
20
def sanitize_filename(filename: str) -> str:
    """
    Sanitize filename to prevent path traversal attacks
    Args:
        filename: original filename
    Returns:
        sanitized filename (possibly empty)
    """
    if not filename:
        return ""

    # drop parent-directory references and surrounding path punctuation
    cleaned = re.sub(r"\.\.(/|\\\\\\)?", "", filename)
    cleaned = cleaned.strip("./\\")

    # whitelist: letters, digits, underscore, dot, hyphen
    cleaned = re.sub(r"[^a-zA-Z0-9_.-]", "", cleaned)

    # collapse any remaining ".." runs into single dots
    while ".." in cleaned:
        cleaned = cleaned.replace("..", ".")

    # no hidden files: strip a single leading dot
    if cleaned.startswith("."):
        cleaned = cleaned[1:]

    # enforce a conventional maximum filename length
    return cleaned[:255]
52
+
53
+
54
def sanitize_folder_name(folder_name: str) -> str:
    """
    Sanitize folder name to prevent path traversal attacks
    Args:
        folder_name: original folder name
    Returns:
        sanitized folder name ("output" when empty or fully rejected)
    """
    if not folder_name:
        return "output"  # default folder

    # strip parent-directory references and surrounding path punctuation
    scrubbed = re.sub(r"\.\.(/|\\\\\\)?", "", folder_name)
    scrubbed = scrubbed.strip("./\\")

    # coarse whitelist keeping '/' so subdirectories survive this pass
    scrubbed = re.sub(r"[^a-zA-Z0-9_./-]", "", scrubbed)

    # validate each path segment independently
    kept = []
    for segment in scrubbed.split("/"):
        if not segment or segment in (".", ".."):
            continue
        segment = re.sub(r"[^a-zA-Z0-9_-]", "", segment)
        if segment:
            kept.append(segment)

    # allow at most 3 levels of directory depth
    kept = kept[:3]

    return "/".join(kept) if kept else "output"
88
+
89
+
90
def safe_path_join(base_dir: str, *paths: str) -> str:
    """
    Safe path joining: every component is sanitized and the resolved result
    must stay inside base_dir.

    Args:
        base_dir: base directory
        *paths: path components to join
    Returns:
        resolved absolute path inside base_dir
    Raises:
        ValueError: if the resolved path escapes base_dir
    """
    # sanitize each component before joining
    components = []
    for raw in paths:
        if not raw:
            continue
        piece = re.sub(r"\.\.(/|\\\\\\)?", "", raw)
        piece = piece.strip("./\\")
        piece = re.sub(r"[^a-zA-Z0-9_.-]", "", piece)
        if piece:
            components.append(piece)

    candidate = os.path.join(base_dir, *components)

    # resolve symlinks/normalization on both sides before containment check
    base_dir = os.path.realpath(base_dir)
    full_path = os.path.realpath(os.path.normpath(candidate))

    if os.path.commonpath([base_dir, full_path]) != base_dir:
        raise ValueError(f"Path traversal detected: {full_path} is outside {base_dir}")

    return full_path
123
+
124
+
125
+ def _get_root_dir() -> str:
126
+ return os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
127
+
128
+
129
def get_output_dir(sub_path: str = "") -> str:
    """Resolve *sub_path* safely under the repository root.

    The root directory is created if missing; segments are validated via
    safe_path_join so the result can never escape the root.
    """
    root = _get_root_dir()
    if not os.path.exists(root):
        os.makedirs(root, exist_ok=True)

    segments = [seg for seg in sub_path.replace("\\", "/").split("/") if seg] if sub_path else []
    return safe_path_join(root, *segments)
138
+
139
+
140
def save_visualization_data(
    output: Dict[str, Union[Tensor, list[str]]],
    text: str,
    rewritten_text: Union[str, list[str]],
    timestamp: str,
    output_dir: Optional[str] = None,
    output_filename: Optional[str] = None,
):
    """Persist generated motion as per-sample NPZ files plus a _meta.json.

    Args:
        output: Pipeline output; must contain batched "rot6d" and "transl"
            tensors (one batch entry per generated sample).
        text: Original user prompt (stored in metadata).
        rewritten_text: Rewritten prompt(s); a bare string is wrapped in a list.
        timestamp: Timestamp string, used as default file basename.
        output_dir: Target directory (default: <root>/output/gradio).
        output_filename: Optional basename override for all written files.

    Returns:
        (memory_data, base_filename): in-memory mirror of the saved content
        (including per-sample SMPL dicts for the FBX exporter) and the base
        filename used for the _meta.json / _NNN.npz files.
    """
    from ..pipeline.body_model import construct_smpl_data_dict

    if output_dir is None:
        output_dir = get_output_dir(sub_path="output/gradio")
    os.makedirs(output_dir, exist_ok=True)

    # for metadata
    base_filename = output_filename if output_filename else timestamp
    meta_path = safe_path_join(output_dir, f"{base_filename}_meta.json")
    if isinstance(rewritten_text, str):
        rewritten_text = [rewritten_text]
    batch_size = output["rot6d"].shape[0]
    meta_data = {
        "timestamp": timestamp,
        "text": text,
        "text_rewrite": rewritten_text,
        "num_samples": batch_size,
        "base_filename": base_filename,
    }

    with _FILE_ACCESS_LOCK:
        with open(meta_path, "w") as f:
            json.dump(meta_data, f, indent=2)

    # for smpl data
    rot6d = output["rot6d"]
    transl = output["transl"]

    all_smpl_data = []  # for FBX generator

    for bb in range(batch_size):
        # build data (clone so downstream mutation cannot touch the batch)
        smpl_data = construct_smpl_data_dict(rot6d[bb].clone(), transl[bb].clone())
        all_smpl_data.append(smpl_data)

        # prepare dictionary to save into NPZ
        npz_dict = {}
        npz_dict["gender"] = np.array([smpl_data.get("gender", "neutral")], dtype=str)

        for key in ["Rh", "trans", "poses", "betas"]:
            if key in smpl_data:
                val = smpl_data[key]
                # normalize to numpy so np.savez_compressed accepts it
                if isinstance(val, (list, tuple)):
                    val = np.array(val)
                elif isinstance(val, torch.Tensor):
                    val = val.cpu().numpy()
                npz_dict[key] = val

        # save single NPZ
        sample_filename = f"{base_filename}_{bb:03d}.npz"
        sample_path = safe_path_join(output_dir, sample_filename)

        with _FILE_ACCESS_LOCK:
            np.savez_compressed(sample_path, **npz_dict)

    # construct memory dictionary to return (for compatibility)
    memory_data = {
        "timestamp": timestamp,
        "text": text,
        "text_rewrite": rewritten_text,
        "smpl_data": all_smpl_data,
        "meta_data": [],
    }

    # return base filename, subsequent logic will use this as a basis for finding _meta.json or _000.npz
    return memory_data, base_filename
214
+
215
+
216
def get_cached_captions(folder_name: str, file_name: str) -> List[dict]:
    """Read <file_name>_meta.json and return caption dicts for the viewer.

    If no meta file matches `file_name` directly, a trailing "_NNN" sample
    suffix is stripped and the lookup retried with the base name.
    Returns [] when no metadata is found or it cannot be parsed.
    """

    folder_name = sanitize_folder_name(folder_name)
    file_name = sanitize_filename(file_name)

    base_dir = get_output_dir(folder_name)
    # try to add suffix or find
    meta_path = safe_path_join(base_dir, f"{file_name}_meta.json")

    if not os.path.exists(meta_path):
        if "_" in file_name:
            # file_name may be a per-sample name like "<base>_000": retry with the base
            prefix = file_name.rsplit("_", 1)[0]
            prefix = sanitize_filename(prefix)
            meta_path_alt = safe_path_join(base_dir, f"{prefix}_meta.json")
            if os.path.exists(meta_path_alt):
                meta_path = meta_path_alt
            else:
                return []
        else:
            return []

    try:
        with _FILE_ACCESS_LOCK:
            with open(meta_path, "r") as f:
                data = json.load(f)

        text = data.get("text", "")
        text_rewrite = data.get("text_rewrite", [])

        captions = []
        for i, t in enumerate(text_rewrite):
            # "short caption+" carries the rewritten prompt; the original
            # prompt is attached as "short caption" only when it differs.
            item = {"short caption+": f"{t}", "start_time": None, "end_time": None}
            if text and text != t:
                item["short caption"] = text
            captions.append(item)
        return captions
    except Exception as e:
        # best-effort: a corrupt meta file degrades to "no captions"
        print(f"Error reading meta json: {e}")
        return []
256
+
257
+
258
def get_cached_smpl_frames(folder_name: str, file_name: str) -> List[list]:
    """
    read logic needs to be adjusted:
    1. if file_name is the base name, load all samples
    2. if file_name is a specific sample name, only load that sample

    Returns a per-frame list; each frame is a list of person dicts
    (id, gender, Rh, Th, poses, shapes) suitable for the web viewer.
    Returns [] when nothing matching is found.
    """
    folder_name = sanitize_folder_name(folder_name)
    file_name = sanitize_filename(file_name)

    base_dir = get_output_dir(folder_name)

    npz_direct_path = safe_path_join(base_dir, f"{file_name}.npz")
    meta_path = safe_path_join(base_dir, f"{file_name}_meta.json")

    target_indices = []   # sample indices to load
    base_name = file_name  # basename shared by the "<base>_NNN.npz" files

    if os.path.isfile(npz_direct_path):
        # Case 2: file_name names a single sample ("<base>_NNN").
        try:
            if "_" in file_name:
                prefix, suffix = file_name.rsplit("_", 1)
                if suffix.isdigit():
                    num_samples = 1
                    base_name = prefix
                    target_indices = [int(suffix)]
                else:
                    pass
            else:
                pass
        except ValueError:
            pass
        if not target_indices:
            # NPZ exists but does not follow the "<base>_NNN" convention
            return []
    elif os.path.exists(meta_path):
        # Case 1: file_name is the base name — load every sample listed in meta.
        try:
            with open(meta_path, "r") as f:
                meta = json.load(f)
            num_samples = meta.get("num_samples", 0)
            target_indices = range(num_samples)
        except Exception as e:
            print(f"Error reading meta: {e}")
            return []
    else:
        return []

    all_people = []

    for i in target_indices:
        npz_path = safe_path_join(base_dir, f"{base_name}_{i:03d}.npz")
        if not os.path.exists(npz_path):
            # tolerate missing samples rather than failing the whole batch
            continue

        try:
            with _FILE_ACCESS_LOCK:
                with np.load(npz_path, allow_pickle=False) as data:
                    # read single person data
                    gender = str(data["gender"][0])
                    Rh = data["Rh"]
                    Th = data["trans"]
                    poses = data["poses"]
                    betas = data["betas"]

                    # flatten per-joint axis-angle to one vector per frame
                    if poses.ndim == 3:
                        poses = poses.reshape(poses.shape[0], -1)

                    person_frames = []
                    for f in range(len(poses)):
                        frame = {
                            "id": i,
                            "gender": gender,
                            "Rh": Rh[f : f + 1].tolist(),
                            "Th": Th[f : f + 1].tolist(),
                            "poses": poses[f : f + 1].tolist(),
                            "shapes": betas.tolist(),
                        }
                        person_frames.append([frame])
                    all_people.append(person_frames)
        except Exception as e:
            print(f"Error loading {npz_path}: {e}")

    # merge: frame f of the combined timeline holds every person present at f
    combined_frames = []
    max_frames = max(len(p) for p in all_people) if all_people else 0
    for f_idx in range(max_frames):
        frame_content = []
        for person_seq in all_people:
            if f_idx < len(person_seq):
                frame_content.extend(person_seq[f_idx])
        combined_frames.append(frame_content)

    return combined_frames
349
+
350
+
351
def generate_static_html_content(
    folder_name: str,
    file_name: str,
    hide_captions: bool = False,
) -> str:
    """
    Generate static HTML content with embedded SMPL data and captions.
    All JavaScript code is embedded directly in the HTML template,
    so no external static resources are needed.

    Args:
        folder_name: The folder name containing the NPZ/meta files
        file_name: The base file name (without extension)
        hide_captions: Whether to hide captions in the visualization

    Returns:
        The HTML content as a string

    Raises:
        ValueError: If no SMPL data can be found for folder_name/file_name.
    """
    # Load SMPL data
    smpl_frames = get_cached_smpl_frames(folder_name, file_name)
    if not smpl_frames:
        raise ValueError(f"No SMPL data found for {folder_name}/{file_name}")

    # Load captions
    captions = []
    if not hide_captions:
        captions = get_cached_captions(folder_name, file_name)

    # Generate caption HTML
    caption_html = _generate_caption_html(captions, hide_captions)

    # Convert SMPL data to JSON (ensure_ascii=False keeps non-ASCII captions readable)
    smpl_data_json = json.dumps(smpl_frames, ensure_ascii=False)

    # Load template
    template_path = os.path.join(_TEMPLATE_DIR, "index_wooden_static.html")
    with open(template_path, "r", encoding="utf-8") as f:
        template_content = f.read()

    # Replace placeholders with actual data
    html_content = template_content.replace("{{ smpl_data_json }}", smpl_data_json)
    html_content = html_content.replace("{{ caption_html }}", caption_html)

    print(f">>> Generated static HTML content for {folder_name}/{file_name}")
    return html_content
396
+
397
+
398
def generate_static_html(
    folder_name: str,
    file_name: str,
    output_dir: str,
    hide_captions: bool = False,
) -> str:
    """
    Write a self-contained visualization HTML file and return its path.

    The content is fully embedded (SMPL data, captions, JavaScript), so the
    resulting file needs no external static resources.

    Args:
        folder_name: The folder name containing the NPZ/meta files
        file_name: The base file name (without extension)
        output_dir: Directory to save the generated HTML file
        hide_captions: Whether to hide captions in the visualization

    Returns:
        The path to the generated HTML file
    """
    content = generate_static_html_content(folder_name, file_name, hide_captions)

    os.makedirs(output_dir, exist_ok=True)
    destination = os.path.join(output_dir, f"{file_name}_vis.html")

    # serialize with the other file writers in this module
    with _FILE_ACCESS_LOCK:
        with open(destination, "w", encoding="utf-8") as handle:
            handle.write(content)

    print(f">>> Generated static HTML: {destination}")
    return destination
431
+
432
+
433
+ def _generate_caption_html(captions: List[dict], hide_captions: bool = False) -> str:
434
+ """
435
+ Generate the caption overlay HTML.
436
+
437
+ Args:
438
+ captions: List of caption dictionaries
439
+ hide_captions: Whether to hide captions
440
+
441
+ Returns:
442
+ HTML string for caption overlay
443
+ """
444
+ if hide_captions or not captions:
445
+ return ""
446
+
447
+ caption_items = []
448
+ for caption in captions:
449
+ # Get the display text (prefer rewritten text)
450
+ text = caption.get("short caption+") or caption.get("short caption") or "No caption"
451
+ caption_items.append(f'<div class="caption-item">{text}</div>')
452
+
453
+ captions_html = "\n".join(caption_items)
454
+
455
+ return f'''
456
+ <div class="caption-overlay">
457
+ <div class="motion-info" id="motion-info">
458
+ <div class="captions-section">
459
+ {captions_html}
460
+ </div>
461
+ </div>
462
+ </div>
463
+ '''
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://gitlab.inria.fr/api/v4/projects/18692/packages/pypi/simple
2
+ huggingface_hub==0.30.0
3
+
4
+ torch==2.5.1
5
+ torchvision==0.20.1
6
+ accelerate==0.30.1
7
+ diffusers==0.26.3
8
+ transformers==4.53.3
9
+ einops==0.8.1
10
+ safetensors==0.5.3
11
+
12
+ numpy>=1.24.0,<2.0
13
+ scipy>=1.10.0
14
+ transforms3d==0.4.2
15
+
16
+ PyYAML==6.0
17
+ omegaconf==2.3.0
18
+ click==8.1.3
19
+ requests==2.32.4
20
+ openai==1.78.1
21
+
22
+ fbxsdkpy==2020.1.post2
23
+
24
+ torchdiffeq==0.2.5
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp ADDED

Git LFS Details

  • SHA256: aaed26cb89635f9d995c9c373919185c412f00f6b4a903873ac75ccd7a549439
  • Pointer size: 131 Bytes
  • Size of remote file: 228 kB
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp ADDED

Git LFS Details

  • SHA256: fe4bd8b80aadf6e414c3c07acc57532cda754deb93f0cdc6f211387b8beca97d
  • Pointer size: 130 Bytes
  • Size of remote file: 30.4 kB
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp ADDED

Git LFS Details

  • SHA256: 9ad8bd31dfda3a8356ac5c9d6bdc7457f0197ae23754f9ac196f86d872b26781
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
scripts/gradio/static/assets/dump_wooden/faces.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:777b0806d2843c797ed18644eecc11466ff822b33b02d263c22f8ad3730e9bb5
3
+ size 290376
scripts/gradio/static/assets/dump_wooden/j_template.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f488cc9d0b650a816f5a8eda49eda7bc796f490e9130a6e0dec5be137d7b929
3
+ size 624
scripts/gradio/static/assets/dump_wooden/joint_names.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "Pelvis",
3
+ "L_Hip",
4
+ "R_Hip",
5
+ "Spine1",
6
+ "L_Knee",
7
+ "R_Knee",
8
+ "Spine2",
9
+ "L_Ankle",
10
+ "R_Ankle",
11
+ "Spine3",
12
+ "L_Foot",
13
+ "R_Foot",
14
+ "Neck",
15
+ "L_Collar",
16
+ "R_Collar",
17
+ "Head",
18
+ "L_Shoulder",
19
+ "R_Shoulder",
20
+ "L_Elbow",
21
+ "R_Elbow",
22
+ "L_Wrist",
23
+ "R_Wrist",
24
+ "L_Index1",
25
+ "L_Index2",
26
+ "L_Index3",
27
+ "L_Middle1",
28
+ "L_Middle2",
29
+ "L_Middle3",
30
+ "L_Pinky1",
31
+ "L_Pinky2",
32
+ "L_Pinky3",
33
+ "L_Ring1",
34
+ "L_Ring2",
35
+ "L_Ring3",
36
+ "L_Thumb1",
37
+ "L_Thumb2",
38
+ "L_Thumb3",
39
+ "R_Index1",
40
+ "R_Index2",
41
+ "R_Index3",
42
+ "R_Middle1",
43
+ "R_Middle2",
44
+ "R_Middle3",
45
+ "R_Pinky1",
46
+ "R_Pinky2",
47
+ "R_Pinky3",
48
+ "R_Ring1",
49
+ "R_Ring2",
50
+ "R_Ring3",
51
+ "R_Thumb1",
52
+ "R_Thumb2",
53
+ "R_Thumb3"
54
+ ]
scripts/gradio/static/assets/dump_wooden/joints.ply ADDED
Binary file (782 Bytes). View file
 
scripts/gradio/static/assets/dump_wooden/keypoints.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f488cc9d0b650a816f5a8eda49eda7bc796f490e9130a6e0dec5be137d7b929
3
+ size 624
scripts/gradio/static/assets/dump_wooden/kintree.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98a20fa3b53193790b63d9ac3a9c917f2f70fbe6e053dca495c75317ff4b756a
3
+ size 208
scripts/gradio/static/assets/dump_wooden/skinIndice.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:846794fb90ea01e069435ad242caefb3d5c2f913fef3255c247f48836ddc1bda
3
+ size 194048
scripts/gradio/static/assets/dump_wooden/skinWeights.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:343eac2902627ee6b45b547eb7d9f1526562eca7ba178d1dc71f9e466f307a77
3
+ size 388096
scripts/gradio/static/assets/dump_wooden/uvs.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23c92c60b609261d927990d31cbf6ace0c14cb70a5ac753a5f3927cb8c5c8191
3
+ size 194048
scripts/gradio/static/assets/dump_wooden/v_template.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9fbe1f34bfe8a07442d11166e169318022a18da8bc62ce0a9930dfdb3171050
3
+ size 291072
scripts/gradio/templates/index_wooden_static.html ADDED
@@ -0,0 +1,1205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html>
3
+
4
+ <head>
5
+ <title>Motion Visualization</title>
6
+ <meta charset="UTF-8">
7
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
8
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
9
+ <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
10
+ <script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.10.2/dist/umd/popper.min.js"></script>
11
+ <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.min.js"></script>
12
+ <style>
13
+ html, body {
14
+ background: #1a1a2e !important;
15
+ color: #e2e8f0;
16
+ margin: 0;
17
+ padding: 0;
18
+ }
19
+ .container {
20
+ padding: 0;
21
+ border: none;
22
+ background: #1a1a2e;
23
+ }
24
+ .alert-success {
25
+ display: none;
26
+ }
27
+ </style>
28
+ </head>
29
+
30
+ <body>
31
+
32
+ <!-- Fullscreen 3D container -->
33
+ <div class="fullscreen-container">
34
+ <!-- 3D viewport -->
35
+ <div id="vis3d"></div>
36
+
37
+ <!-- Floating caption overlay (centered at top) -->
38
+ {{ caption_html }}
39
+
40
+ <!-- Floating progress control panel (centered at bottom) -->
41
+ <div class="control-overlay">
42
+ <div class="control-row-minimal">
43
+ <div class="progress-container">
44
+ <input type="range" id="progressSlider" class="progress-slider-minimal" min="0" max="100" value="0">
45
+ </div>
46
+ <div class="frame-counter">
47
+ <span id="currentFrame">0</span> / <span id="totalFrames">0</span>
48
+ </div>
49
+ </div>
50
+ </div>
51
+
52
+ <!-- Loading status overlay -->
53
+ <div class="loading-overlay" id="loadingStatus">
54
+ <i class="fas fa-spinner fa-spin"></i> Loading...
55
+ </div>
56
+
57
+ <!-- Hidden controls for functionality -->
58
+ <div style="display: none;">
59
+ <button id="playPauseBtn"></button>
60
+ <button id="resetBtn"></button>
61
+ <input type="range" id="speedSlider" min="0.1" max="3" step="0.1" value="1">
62
+ <span id="speedValue">1.0x</span>
63
+ </div>
64
+ </div>
65
+
66
+ <!-- Add Font Awesome for icons -->
67
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
68
+
69
+ <script type="importmap">
70
+ {
71
+ "imports": {
72
+ "three": "https://cdn.jsdelivr.net/npm/three@0.160.0/build/three.module.js",
73
+ "three/addons/": "https://cdn.jsdelivr.net/npm/three@0.160.0/examples/jsm/"
74
+ }
75
+ }
76
+ </script>
77
+
78
+ <!-- Embedded SMPL Data - Generated by Python -->
79
+ <script type="application/json" id="smpl-data-json">
80
+ {{ smpl_data_json }}
81
+ </script>
82
+
83
+ <script type="module">
84
+ import * as THREE from 'three';
85
+ import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
86
+
87
+ // ============================================================
88
+ // EMBEDDED: create_ground.js functions
89
+ // ============================================================
90
+
91
function getAdaptiveGridSize(sample_data, default_size = 5) {
    // Choose a ground-plane size that comfortably contains the motion data;
    // fall back to `default_size` when no data is supplied.
    if (!sample_data) {
        return default_size;
    }
    const bounds = calculateDataBounds(sample_data);
    const grid_size = Math.max(bounds.maxRange * 3, 5);
    console.log(`Adaptive ground size: ${grid_size.toFixed(2)}, data range: ${bounds.maxRange.toFixed(2)}`);
    return grid_size;
}
100
+
101
function createBaseChessboard(
    grid_size = 5,
    divisions = 10,
    white = "#ffffff",
    black = "#444444",
    texture_size = 1024,
    sample_data = null,
) {
    // Build a checkerboard ground plane: a canvas-painted texture on a
    // double-sided, slightly transparent PlaneGeometry.
    if (sample_data) {
        // Grow the board to fit the motion data when sample frames are given.
        grid_size = getAdaptiveGridSize(sample_data, grid_size);
    }

    // Round the texture size down to a multiple of `divisions` so each cell
    // covers a whole number of pixels (keeps edges crisp with NearestFilter).
    const px = Math.floor(texture_size / divisions) * divisions;
    const board = document.createElement("canvas");
    board.width = px;
    board.height = px;
    const ctx = board.getContext("2d");
    ctx.imageSmoothingEnabled = false;

    const cell = px / divisions;
    for (let row = 0; row < divisions; row++) {
        for (let col = 0; col < divisions; col++) {
            ctx.fillStyle = (row + col) % 2 === 0 ? white : black;
            ctx.fillRect(row * cell, col * cell, cell, cell);
        }
    }

    const texture = new THREE.CanvasTexture(board);
    texture.wrapS = THREE.RepeatWrapping;
    texture.wrapT = THREE.RepeatWrapping;
    // Nearest filtering + no mipmaps: hard checker edges, no blur at distance.
    texture.magFilter = THREE.NearestFilter;
    texture.minFilter = THREE.NearestFilter;
    texture.generateMipmaps = false;

    const material = new THREE.MeshStandardMaterial({
        map: texture,
        side: THREE.DoubleSide,
        transparent: true,
        opacity: 0.85,
        roughness: 0.9,
        metalness: 0.1,
        emissiveIntensity: 0.05,
    });

    const plane = new THREE.Mesh(new THREE.PlaneGeometry(grid_size, grid_size), material);
    plane.receiveShadow = true;

    return plane;
}
151
+
152
function getChessboard(...args) {
    // Checkerboard for a Z-up world: flip the plane so its front faces +Z.
    const board = createBaseChessboard(...args);
    board.rotation.x = -Math.PI;
    return board;
}
157
+
158
function getChessboardXZ(...args) {
    // Checkerboard for a Y-up world: rotate the plane flat onto the XZ plane.
    const board = createBaseChessboard(...args);
    board.rotation.x = -Math.PI / 2;
    return board;
}
163
+
164
function getCoordinate(axisLength) {
    // Build an RGB axis gizmo: X=red, Y=green, Z=blue, each a single line
    // segment of length `axisLength` starting at the origin.
    const axes = new THREE.Group();
    const specs = [
        [new THREE.Vector3(axisLength, 0, 0), 0xff0000],
        [new THREE.Vector3(0, axisLength, 0), 0x00ff00],
        [new THREE.Vector3(0, 0, axisLength), 0x0000ff],
    ];
    for (const [tip, color] of specs) {
        const geometry = new THREE.BufferGeometry().setFromPoints([
            new THREE.Vector3(0, 0, 0),
            tip,
        ]);
        axes.add(new THREE.Line(geometry, new THREE.LineBasicMaterial({ color })));
    }
    return axes;
}
193
+
194
function calculateDataBounds(sample_data) {
    // Scan every joint position in every frame and return axis-aligned bounds.
    // `maxRange` deliberately ignores the Y extent: it is used to size the
    // ground plane, which only needs to cover horizontal travel (X/Z).
    let minX = Infinity, maxX = -Infinity;
    let minY = Infinity, maxY = -Infinity;
    let minZ = Infinity, maxZ = -Infinity;

    if (sample_data && sample_data.length > 0) {
        for (const frame of sample_data) {
            if (!frame.positions || !Array.isArray(frame.positions)) continue;
            for (const pos of frame.positions) {
                let x, y, z;
                if (typeof pos === "object") {
                    // Accept both {x,y,z} objects and [x,y,z] triples; arrays
                    // are objects too, so indexed access is the fallback here.
                    x = pos.x !== undefined ? pos.x : pos[0];
                    y = pos.y !== undefined ? pos.y : pos[1];
                    z = pos.z !== undefined ? pos.z : pos[2];
                } else if (Array.isArray(pos)) {
                    [x, y, z] = pos;
                }
                if (x === undefined || y === undefined || z === undefined) continue;
                if (x < minX) minX = x;
                if (x > maxX) maxX = x;
                if (y < minY) minY = y;
                if (y > maxY) maxY = y;
                if (z < minZ) minZ = z;
                if (z > maxZ) maxZ = z;
            }
        }
    }

    // No usable positions: collapse everything to a zero-size box at origin.
    if (minX === Infinity || maxX === -Infinity) {
        minX = maxX = minY = maxY = minZ = maxZ = 0;
    }

    const rangeX = Math.abs(maxX - minX);
    const rangeY = Math.abs(maxY - minY);
    const rangeZ = Math.abs(maxZ - minZ);
    const maxRange = Math.max(rangeX, rangeZ);

    return { minX, maxX, minY, maxY, minZ, maxZ, rangeX, rangeY, rangeZ, maxRange };
}
236
+
237
+ // ============================================================
238
+ // EMBEDDED: create_scene.js functions
239
+ // ============================================================
240
+
241
// Configure camera pose, background, fog, shadows, lights, and an optional
// checkerboard ground for `scene`. `axis_up` / `axis_forward` select between
// a Z-up world (e.g. SMPL convention) and the three.js default Y-up world.
// Returns 0 (callers ignore the return value).
function create_scene(scene, camera, renderer, use_ground = true, axis_up = "z", axis_forward = "-y") {
    // NOTE(review): width/height are computed but never used below —
    // presumably leftover from an earlier aspect-ratio setup; confirm before removing.
    const width = document.querySelector(".container") ? document.querySelector(".container").offsetWidth : window.innerWidth;
    const height = width;

    // Position the camera on the viewer side of the subject for the chosen
    // up/forward convention, looking roughly at chest height.
    if (axis_up == "z") {
        camera.up.set(0, 0, 1);
        if (axis_forward == "-y") {
            camera.position.set(0, -3, 3);
        } else if (axis_forward == "y") {
            camera.position.set(0, 3, 3);
        }
        camera.lookAt(new THREE.Vector3(0, 0, 1.5));
    } else if (axis_up == "y") {
        camera.up.set(0, 1, 0);
        if (axis_forward == "z") {
            camera.position.set(0, 2.5, 5);
        } else if (axis_forward == "-z") {
            camera.position.set(0, 2.5, -5);
        }
        camera.lookAt(new THREE.Vector3(0, 1, 0));
    }

    scene.background = new THREE.Color(0x000000);
    // Exponential fog fades distant geometry into the dark background.
    scene.fog = new THREE.FogExp2(0x424242, 0.06);

    renderer.shadowMap.enabled = true;
    renderer.shadowMap.type = THREE.PCFSoftShadowMap;

    // Soft ambient-style illumination from sky/ground colors.
    const hemisphereLight = new THREE.HemisphereLight(0xffffff, 0x444444, 1.8);
    hemisphereLight.position.set(0, 2, 0);
    scene.add(hemisphereLight);

    // Key light — the only shadow caster; placed on the camera's side so
    // shadows fall away from the viewer.
    const directionalLight = new THREE.DirectionalLight(0xffffff, 1.5);
    if (axis_up == "z") {
        if (axis_forward == "-y") {
            directionalLight.position.set(-3, 1, 5);
        } else if (axis_forward == "y") {
            directionalLight.position.set(3, 1, 5);
        }
    } else if (axis_up == "y") {
        if (axis_forward == "z") {
            directionalLight.position.set(3, 5, 4);
        } else if (axis_forward == "-z") {
            directionalLight.position.set(3, 5, -4);
        }
    }
    directionalLight.castShadow = true;
    // 2048² shadow map with a ±10-unit ortho frustum; small negative bias
    // reduces shadow acne.
    directionalLight.shadow.mapSize.width = 2048;
    directionalLight.shadow.mapSize.height = 2048;
    directionalLight.shadow.camera.near = 0.5;
    directionalLight.shadow.camera.far = 50;
    directionalLight.shadow.camera.left = -10;
    directionalLight.shadow.camera.right = 10;
    directionalLight.shadow.camera.top = 10;
    directionalLight.shadow.camera.bottom = -10;
    directionalLight.shadow.bias = -0.0001;
    scene.add(directionalLight);

    // Cool fill light from the opposite side softens the key light's shadows.
    const fillLight = new THREE.DirectionalLight(0xaaccff, 0.4);
    fillLight.position.set(-3, 3, -2);
    scene.add(fillLight);

    // Warm rim light from behind separates the subject from the background.
    const rimLight = new THREE.DirectionalLight(0xffeedd, 0.3);
    rimLight.position.set(0, 4, -5);
    scene.add(rimLight);

    if (use_ground) {
        // Ground is named 'ground' so camera fitting can exclude it.
        if (axis_up == "z") {
            var plane = getChessboard(50, 50, '#ffffff', '#3a3a3a', 1024);
            plane.name = 'ground';
            plane.receiveShadow = true;
            scene.add(plane);
        } else if (axis_up == "y") {
            var plane = getChessboardXZ(50, 50, '#ffffff', '#3a3a3a', 1024);
            plane.name = 'ground';
            plane.receiveShadow = true;
            scene.add(plane);
        }
    }

    return 0;
}
323
+
324
// Frame the camera on the scene's visible meshes (excluding lights, helpers,
// ground planes, and anything named in `excludeNames`). Computes the bounding
// sphere of the remaining meshes, then backs the camera off along a fixed
// 25°-elevation / 45°-azimuth direction until the sphere fits both the
// vertical and horizontal FOV, with `margin` slack. Updates OrbitControls
// target and zoom limits when `controls` is given. No-op on an empty scene.
function fitCameraToScene(scene, camera, controls = null, opts = {}) {
    const { margin = 1.05, axis_up = "y", excludeNames = ["ground"] } = opts;

    const box = new THREE.Box3();
    const tmp = new THREE.Box3();
    let has = false;

    scene.traverse((obj) => {
        if (!obj || !obj.visible) return;
        if (obj.isLight) return;
        const t = obj.type || "";
        if (t.endsWith("Helper")) return;
        if (excludeNames && excludeNames.includes(obj.name)) return;

        if (obj.isMesh) {
            // Skip plane geometries too (the ground may not always be named).
            if (obj.geometry && obj.geometry.type === "PlaneGeometry") return;
            try {
                tmp.setFromObject(obj);
                if (!tmp.isEmpty()) {
                    if (!has) {
                        box.copy(tmp);
                        has = true;
                    } else {
                        box.union(tmp);
                    }
                }
            // Best-effort: some objects (e.g. not-yet-bound skinned meshes)
            // can throw during bounds computation — ignore them.
            } catch (_) {}
        }
    });

    if (!has || box.isEmpty()) return;

    const sphere = new THREE.Sphere();
    box.getBoundingSphere(sphere);
    const center = sphere.center.clone();
    // Clamp the radius so degenerate (point-like) scenes still get a finite distance.
    const radius = Math.max(sphere.radius, 1e-3);

    // Distance required so the bounding sphere fits in both FOV directions.
    const vFov = THREE.MathUtils.degToRad(camera.fov);
    const hFov = 2 * Math.atan(Math.tan(vFov / 2) * camera.aspect);
    const distV = radius / Math.sin(vFov / 2);
    const distH = radius / Math.sin(hFov / 2);
    const dist = Math.max(distV, distH) * margin;

    // Fixed pleasant viewing direction: 25° above horizon, 45° azimuth.
    const elev = THREE.MathUtils.degToRad(25);
    const azim = Math.PI / 4;
    const horiz = Math.cos(elev);
    let dir;

    if (axis_up === "y") {
        dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.sin(elev), Math.cos(azim) * horiz);
        camera.up.set(0, 1, 0);
    } else {
        dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.cos(azim) * horiz, Math.sin(elev));
        camera.up.set(0, 0, 1);
    }

    camera.position.copy(center).add(dir.multiplyScalar(dist));
    camera.updateProjectionMatrix();
    camera.lookAt(center);

    if (controls) {
        // Re-center orbiting on the subject and bound the zoom range.
        controls.target.copy(center);
        controls.minDistance = Math.max(radius * 0.2, 0.1);
        controls.maxDistance = Math.max(dist * 3, controls.minDistance + 0.1);
        controls.update();
    }
}
391
+
392
+ // ============================================================
393
+ // EMBEDDED: load_wooden.js functions
394
+ // ============================================================
395
+
396
+ const NUM_SKIN_WEIGHTS = 4;
397
+
398
+ const SMPLH_JOINT_NAMES = [
399
+ "Pelvis", "L_Hip", "R_Hip", "Spine1",
400
+ "L_Knee", "R_Knee", "Spine2",
401
+ "L_Ankle", "R_Ankle", "Spine3",
402
+ "L_Foot", "R_Foot", "Neck", "L_Collar", "R_Collar", "Head",
403
+ "L_Shoulder", "R_Shoulder", "L_Elbow", "R_Elbow",
404
+ "L_Wrist", "R_Wrist",
405
+ "L_Index1", "L_Index2", "L_Index3",
406
+ "L_Middle1", "L_Middle2", "L_Middle3",
407
+ "L_Pinky1", "L_Pinky2", "L_Pinky3",
408
+ "L_Ring1", "L_Ring2", "L_Ring3",
409
+ "L_Thumb1", "L_Thumb2", "L_Thumb3",
410
+ "R_Index1", "R_Index2", "R_Index3",
411
+ "R_Middle1", "R_Middle2", "R_Middle3",
412
+ "R_Pinky1", "R_Pinky2", "R_Pinky3",
413
+ "R_Ring1", "R_Ring2", "R_Ring3",
414
+ "R_Thumb1", "R_Thumb2", "R_Thumb3",
415
+ ];
416
+
417
+ const DEFAULT_EDGES = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50];
418
+
419
// Fetch the wooden-character assets (template vertices, faces, skin weights,
// joint template, UVs, kinematic tree, joint names, base-color texture) and
// assemble a THREE.SkinnedMesh with its bone hierarchy.
// Returns { bones, skeleton, mesh, jointNames, edges }.
// NOTE(review): `shapes` and `gender` are never read in this function body —
// presumably reserved for future shape/gender-specific models; confirm.
async function load_wooden(shapes, gender, basePath = '/static/assets/dump_wooden') {
    console.log("Loading wooden model...");
    // NOTE(review): the `basePath` parameter is unconditionally overridden by
    // this hardcoded GitHub URL, so the local /static assets are never used —
    // confirm whether this is intentional for the hosted deployment.
    basePath = "https://raw.githubusercontent.com/chingswy/WoodenModel/refs/heads/main/dump_wooden"
    console.log(`Using base path: ${basePath}`);

    // Binary blobs fetched in parallel below; order matters for indexing.
    const urls = [
        `${basePath}/v_template.bin`,
        `${basePath}/faces.bin`,
        `${basePath}/skinWeights.bin`,
        `${basePath}/skinIndice.bin`,
        `${basePath}/j_template.bin`,
        `${basePath}/uvs.bin`,
    ];

    // Kinematic tree: parent index per joint; fall back to the SMPL-H default.
    let edges = [...DEFAULT_EDGES];
    try {
        const kintreeResponse = await fetch(`${basePath}/kintree.bin`);
        if (kintreeResponse.ok) {
            const kintreeBuffer = await kintreeResponse.arrayBuffer();
            edges = Array.from(new Int32Array(kintreeBuffer));
            console.log(`Loaded kintree with ${edges.length} joints`);
        }
    } catch (e) {
        console.log('Using default kintree');
    }

    // Joint names: optional JSON, defaulting to the SMPL-H naming.
    let jointNames = [...SMPLH_JOINT_NAMES];
    try {
        const namesResponse = await fetch(`${basePath}/joint_names.json`);
        if (namesResponse.ok) {
            jointNames = await namesResponse.json();
            console.log(`Loaded ${jointNames.length} joint names`);
        }
    } catch (e) {
        console.log('Using default joint names');
    }

    const buffers = await Promise.all(urls.map(url => fetch(url).then(response => response.arrayBuffer())));
    const v_template = new Float32Array(buffers[0]);
    const faces = new Uint16Array(buffers[1]);
    const skinWeights = new Float32Array(buffers[2]);
    const skinIndices = new Uint16Array(buffers[3]);
    const keypoints = new Float32Array(buffers[4]);   // rest-pose joint positions (x,y,z per joint)
    const uvs = new Float32Array(buffers[5]);

    console.log(`Vertices: ${v_template.length / 3}, Faces: ${faces.length / 3}, Joints: ${keypoints.length / 3}`);

    const geometry = new THREE.BufferGeometry();
    geometry.setAttribute('position', new THREE.BufferAttribute(v_template, 3));
    geometry.setIndex(new THREE.BufferAttribute(faces, 1));
    // NUM_SKIN_WEIGHTS (=4) influences per vertex, matching three.js skinning.
    geometry.setAttribute('skinIndex', new THREE.BufferAttribute(skinIndices, NUM_SKIN_WEIGHTS));
    geometry.setAttribute('skinWeight', new THREE.BufferAttribute(skinWeights, NUM_SKIN_WEIGHTS));
    geometry.setAttribute('uv', new THREE.BufferAttribute(uvs, 2));

    const numJoints = keypoints.length / 3;

    // Pad the parent list if the kintree is shorter than the joint count.
    while (edges.length < numJoints) {
        edges.push(0);
    }

    // Root bone sits at its absolute rest position; children store positions
    // RELATIVE to their parent (three.js Bone convention).
    var rootBone = new THREE.Bone();
    rootBone.position.set(keypoints[0], keypoints[1], keypoints[2]);
    rootBone.name = jointNames[0] || 'Pelvis';
    var bones = [rootBone];

    for (let i = 1; i < numJoints; i++) {
        const bone = new THREE.Bone();
        const parentIndex = edges[i];

        if (parentIndex >= 0 && parentIndex < i) {
            bone.position.set(
                keypoints[3 * i] - keypoints[3 * parentIndex],
                keypoints[3 * i + 1] - keypoints[3 * parentIndex + 1],
                keypoints[3 * i + 2] - keypoints[3 * parentIndex + 2]
            );
            bone.name = jointNames[i] || `Joint_${i}`;
            bones.push(bone);
            bones[parentIndex].add(bone);
            console.log(`Joint ${i} (${bone.name}): parent=${parentIndex}, pos=${bone.position.toArray()}`);
        } else {
            // Malformed kintree entry: attach to root at the root's origin.
            console.warn(`Invalid parent index ${parentIndex} for joint ${i}, attaching to root`);
            bone.position.set(0, 0, 0);
            bone.name = jointNames[i] || `Joint_${i}`;
            bones.push(bone);
            bones[0].add(bone);
        }
    }

    var skeleton = new THREE.Skeleton(bones);

    geometry.computeVertexNormals();

    const textureLoader = new THREE.TextureLoader();

    // flipY=false: UVs were exported with the image origin at top-left.
    async function loadTextureAsync(url, isSRGB = true) {
        const tex = await textureLoader.loadAsync(url);
        tex.flipY = false;
        if (isSRGB) tex.colorSpace = THREE.SRGBColorSpace;
        return tex;
    }

    const [baseColorMap] = await Promise.all([
        loadTextureAsync(`${basePath}/Boy_lambert4_BaseColor.webp`, true),
    ]);

    const material = new THREE.MeshStandardMaterial({
        map: baseColorMap,
        roughness: 0.6,
        metalness: 0.2,
        envMapIntensity: 1.5,
    });

    var mesh = new THREE.SkinnedMesh(geometry, material);
    mesh.castShadow = true;
    mesh.receiveShadow = true;
    // Parent the root bone under the mesh, then bind the skeleton to it.
    mesh.add(bones[0]);
    mesh.bind(skeleton);

    console.log(`Wooden model loaded: ${numJoints} joints, ${v_template.length / 3} vertices`);

    return { bones, skeleton, mesh, jointNames, edges };
}
541
+
542
+ // ============================================================
543
+ // Main Application Code
544
+ // ============================================================
545
+
546
+ let scene, camera, renderer;
547
+ let controls;
548
+ let infos;
549
+ let currentFrame = 0;
550
+ let total_frame = 0;
551
+ const baseIntervalTime = 30;
552
+ var model_mesh = {};
553
+
554
+ let isPlaying = false;
555
+ let lastFrameTime = 0;
556
+ let playbackSpeed = 1.0;
557
+ let animationId = null;
558
+ let modelsLoaded = false;
559
+ let expectedModelCount = 0;
560
+ let loadedModelCount = 0;
561
+
562
+ let ignoreGlobalTrans = false;
563
+ let currentOffsets = [];
564
+
565
// Apply frame `currentFrame` of the loaded SMPL parameter sequence to every
// character's bone hierarchy: global translation (Th), global orientation
// (Rh, axis-angle), and per-joint body poses (axis-angle triples). Bails out
// silently until every referenced model mesh has finished loading.
const updateFrame = () => {
    if (!infos || currentFrame >= total_frame || !modelsLoaded) return;

    // One frame holds the parameters for every character in the scene.
    const info = infos[currentFrame];
    let allModelsReady = true;

    info.forEach(smpl_params => {
        if (!(smpl_params.id in model_mesh)) {
            allModelsReady = false;
        }
    });

    if (!allModelsReady) {
        return;
    }

    // Horizontal offsets spread multiple characters apart side by side.
    const offsets = computeOffsets(info.length);
    currentOffsets = offsets;

    info.forEach((smpl_params, b) => {
        const bones = model_mesh[smpl_params.id];
        const meshContainer = bones[0].parent;

        // Global translation goes on the mesh container, not the root bone.
        if (ignoreGlobalTrans) {
            meshContainer.position.set(-offsets[b], 0, 0);
        } else {
            meshContainer.position.set(
                smpl_params.Th[0][0] - offsets[b],
                smpl_params.Th[0][1],
                smpl_params.Th[0][2]
            );
        }

        // Global orientation: Rh is an axis-angle vector (direction = axis,
        // magnitude = angle); applied to the root bone.
        var axis = new THREE.Vector3(smpl_params.Rh[0][0], smpl_params.Rh[0][1], smpl_params.Rh[0][2]);
        var angle = axis.length();
        axis.normalize();
        var quaternion = new THREE.Quaternion().setFromAxisAngle(axis, angle);
        bones[0].quaternion.copy(quaternion);

        var poses_offset = 0;

        // A 69-length pose vector omits the root orientation (23 joints × 3),
        // so shift the read index back one joint; presumably the 72-length
        // variant includes the root triple at the front — TODO confirm.
        if (smpl_params.poses[0].length == 69) {
            poses_offset = -3;
        }

        for (let i = 1; i < bones.length; i++) {
            const startIndex = poses_offset + 3 * i;

            // Joints beyond the pose vector (e.g. fingers for body-only data)
            // simply keep their previous rotation.
            if (startIndex + 2 < smpl_params.poses[0].length) {
                var axis = new THREE.Vector3(
                    smpl_params.poses[0][startIndex],
                    smpl_params.poses[0][startIndex + 1],
                    smpl_params.poses[0][startIndex + 2]
                );
                var angle = axis.length();

                if (angle > 1e-6) {
                    axis.normalize();
                    var quaternion = new THREE.Quaternion().setFromAxisAngle(axis, angle);
                    bones[i].quaternion.copy(quaternion);
                } else {
                    // Near-zero rotation: set identity to avoid normalizing a
                    // zero-length axis.
                    bones[i].quaternion.set(0, 0, 0, 1);
                }
            }
        }
    });

    updateUI();
}
634
+
635
// requestAnimationFrame loop: advance one frame whenever enough wall-clock
// time has elapsed for the current playback speed, wrapping at the end.
const playLoop = (currentTime) => {
    const interval = baseIntervalTime / playbackSpeed;
    if (isPlaying && currentTime - lastFrameTime >= interval) {
        currentFrame += 1;
        if (currentFrame >= total_frame) {
            currentFrame = 0;   // loop playback
        }
        updateFrame();
        lastFrameTime = currentTime;
    }
    // Keep scheduling ourselves only while playing.
    if (isPlaying) {
        animationId = requestAnimationFrame(playLoop);
    }
}
649
+
650
// Mirror playback state into the frame counter and the progress slider.
const updateUI = () => {
    document.getElementById('currentFrame').textContent = currentFrame;
    document.getElementById('totalFrames').textContent = total_frame;
    if (total_frame > 0) {
        const progressSlider = document.getElementById('progressSlider');
        progressSlider.value = (currentFrame / total_frame) * 100;
    }
}
659
+
660
// Show loading progress while models download; flash "Ready" then fade out
// once everything has loaded.
const updateLoadingStatus = () => {
    const el = document.getElementById('loadingStatus');
    if (!el) return;

    if (!modelsLoaded) {
        el.innerHTML = `<i class="fas fa-spinner fa-spin"></i> Loading... (${loadedModelCount}/${expectedModelCount})`;
        el.className = 'loading-overlay';
        return;
    }

    el.innerHTML = '<i class="fas fa-check"></i> Ready';
    el.className = 'loading-overlay complete';
    // Hide the "Ready" badge after a short delay.
    setTimeout(() => {
        el.className = 'loading-overlay hidden';
    }, 1500);
}
675
+
676
// Sync the (hidden) play/pause button's icon and tooltip with playback state.
const updatePlayPauseButton = () => {
    const btn = document.getElementById('playPauseBtn');
    if (!btn) return;
    btn.innerHTML = isPlaying ? '<i class="fas fa-pause"></i>' : '<i class="fas fa-play"></i>';
    btn.title = isPlaying ? 'Pause' : 'Play';
}
688
+
689
// Re-enable every transport control once all models have finished loading.
const enablePlaybackControls = () => {
    const controlIds = ['playPauseBtn', 'resetBtn', 'progressSlider', 'speedSlider'];
    for (const id of controlIds) {
        const el = document.getElementById(id);
        if (!el) continue;
        el.disabled = false;
        el.style.opacity = '1';
        el.style.cursor = 'pointer';
    }
    updatePlayPauseButton();
}
705
+
706
// Start looping playback; no-op while already playing, with no frames, or
// before the models have loaded.
const playAnimation = () => {
    if (isPlaying || !(total_frame > 0) || !modelsLoaded) return;
    isPlaying = true;
    lastFrameTime = performance.now();
    animationId = requestAnimationFrame(playLoop);
    updatePlayPauseButton();
}
714
+
715
// Stop playback and cancel any scheduled animation-frame callback.
const pauseAnimation = () => {
    isPlaying = false;
    if (animationId) {
        cancelAnimationFrame(animationId);
        animationId = null;
    }
    updatePlayPauseButton();
}
723
+
724
// Stop playback and rewind to the first frame.
const resetAnimation = () => {
    pauseAnimation();            // also cancels the pending rAF callback
    currentFrame = 0;
    updateFrame();               // repose all characters at frame 0
    updatePlayPauseButton();
}
730
+
731
// Wire up all playback UI: progress-slider scrubbing (mouse + touch, pausing
// while dragging and resuming after), speed-slider, and keyboard shortcuts
// (Space = play/pause, arrows = frame step, Home = reset). Every handler
// no-ops until the models have finished loading.
const initPlaybackControls = () => {
    const progressSlider = document.getElementById('progressSlider');

    // Remember whether playback was running when the user grabbed the slider,
    // so we can resume it on release.
    let wasPlaying = false;
    progressSlider.addEventListener('mousedown', () => {
        if (!modelsLoaded) return;
        wasPlaying = isPlaying;
        if (isPlaying) pauseAnimation();
    });

    // Scrub: map the 0-100 slider value to a frame index, clamped in range.
    progressSlider.addEventListener('input', (e) => {
        if (!modelsLoaded) return;
        const progress = parseFloat(e.target.value);
        currentFrame = Math.floor((progress / 100) * total_frame);
        if (currentFrame >= total_frame) currentFrame = total_frame - 1;
        if (currentFrame < 0) currentFrame = 0;
        updateFrame();
    });

    progressSlider.addEventListener('mouseup', () => {
        if (!modelsLoaded) return;
        if (wasPlaying) playAnimation();
    });

    // Touch devices: same pause-while-scrubbing behavior as mouse.
    progressSlider.addEventListener('touchstart', () => {
        if (!modelsLoaded) return;
        wasPlaying = isPlaying;
        if (isPlaying) pauseAnimation();
    });

    progressSlider.addEventListener('touchend', () => {
        if (!modelsLoaded) return;
        if (wasPlaying) playAnimation();
    });

    // Playback-speed multiplier (0.1x–3x); playLoop reads playbackSpeed live.
    const speedSlider = document.getElementById('speedSlider');
    const speedValue = document.getElementById('speedValue');
    speedSlider.addEventListener('input', (e) => {
        playbackSpeed = parseFloat(e.target.value);
        speedValue.textContent = playbackSpeed.toFixed(1) + 'x';
    });

    // Keyboard transport controls.
    document.addEventListener('keydown', (e) => {
        if (!modelsLoaded) return;
        switch (e.code) {
            case 'Space':
                e.preventDefault();   // keep Space from scrolling the page
                if (isPlaying) {
                    pauseAnimation();
                } else {
                    playAnimation();
                }
                break;
            case 'ArrowLeft':
                e.preventDefault();
                if (currentFrame > 0) {
                    currentFrame--;
                    updateFrame();
                }
                break;
            case 'ArrowRight':
                e.preventDefault();
                if (currentFrame < total_frame - 1) {
                    currentFrame++;
                    updateFrame();
                }
                break;
            case 'Home':
                e.preventDefault();
                resetAnimation();
                break;
        }
    });
}
805
+
806
+ // Load embedded SMPL data directly (no fetch needed)
807
// Bootstrap the viewer from the SMPL data embedded in the page by the Python
// template renderer (no fetch needed): parse the JSON script tag, kick off
// one wooden-model load per character in the first frame, and — once all are
// loaded — enable controls, fit the camera, and auto-start playback.
function loadEmbeddedData() {
    try {
        const smplDataElement = document.getElementById('smpl-data-json');
        if (!smplDataElement) {
            console.error('SMPL data element not found');
            return;
        }

        // datas: array of frames; each frame is an array of per-character
        // SMPL parameter objects ({id, Th, Rh, poses, ...}).
        const datas = JSON.parse(smplDataElement.textContent);

        if (!datas || datas.length === 0) {
            console.error('No SMPL data available');
            return;
        }

        console.log(`Loaded ${datas.length} frames of embedded SMPL data`);
        infos = datas;
        total_frame = datas.length;

        document.getElementById('progressSlider').max = 100;
        updateUI();
        updatePlayPauseButton();

        // One model per character present in the first frame.
        expectedModelCount = infos[0].length;

        loadedModelCount = 0;
        modelsLoaded = false;
        updateLoadingStatus();

        infos[0].forEach(data => {
            // Models load concurrently; the last one to finish triggers the
            // "everything ready" path below.
            load_wooden(null, null).then(result => {
                scene.add(result.mesh);

                result.mesh.castShadow = true;
                result.mesh.receiveShadow = true;

                // updateFrame looks up each character's bones by its id.
                model_mesh[data.id] = result.bones;

                loadedModelCount++;

                if (loadedModelCount === expectedModelCount) {
                    modelsLoaded = true;
                    updateLoadingStatus();
                    updateFrame();
                    enablePlaybackControls();
                    fitCameraToScene(scene, camera, controls, { axis_up: 'y', excludeNames: ['ground'] });
                    // Small delay so the first frame renders before playback starts.
                    setTimeout(() => playAnimation(), 500);
                } else {
                    updateLoadingStatus();
                }
            }).catch(err => {
                console.error("Failed to load wooden model:", err);
            });
        });

        initPlaybackControls();
        animate();   // start the render loop (defined elsewhere in this file)
    } catch (error) {
        console.error('Error loading embedded data:', error);
    }
}
868
+
869
+ init();
870
+ loadEmbeddedData();
871
+
872
// One-time scene setup: renderer, custom light rig, orbit controls, and
// click/drag interaction on the canvas. Must run before loadEmbeddedData(),
// which adds meshes to `scene` and refits the camera.
function init() {
    const width = window.innerWidth;
    const height = window.innerHeight;
    scene = new THREE.Scene();
    camera = new THREE.PerspectiveCamera(45, width / height, 0.1, 50);
    renderer = new THREE.WebGLRenderer({ antialias: true, logarithmicDepthBuffer: true });

    // Shared helper builds the base scene (ground plane etc.).
    create_scene(scene, camera, renderer, true, 'y', 'z');

    renderer.shadowMap.enabled = true;
    renderer.shadowMap.type = THREE.PCFSoftShadowMap;

    scene.background = new THREE.Color(0x424242);
    scene.fog = new THREE.FogExp2(0x424242, 0.06);

    // Strip any lights installed by create_scene; a custom rig is built below.
    scene.children = scene.children.filter(child => !child.isLight);

    const hemisphereLight = new THREE.HemisphereLight(0xffffff, 0x444444, 1.2);
    hemisphereLight.position.set(0, 2, 0);
    scene.add(hemisphereLight);

    // Key light: the only shadow caster, with a tight shadow-camera frustum.
    const directionalLight = new THREE.DirectionalLight(0xffffff, 1.5);
    directionalLight.position.set(3, 5, 4);
    directionalLight.castShadow = true;
    directionalLight.shadow.mapSize.width = 2048;
    directionalLight.shadow.mapSize.height = 2048;
    directionalLight.shadow.camera.near = 0.5;
    directionalLight.shadow.camera.far = 50;
    directionalLight.shadow.camera.left = -10;
    directionalLight.shadow.camera.right = 10;
    directionalLight.shadow.camera.top = 10;
    directionalLight.shadow.camera.bottom = -10;
    directionalLight.shadow.bias = -0.0001;  // small negative bias to reduce shadow acne
    scene.add(directionalLight);

    // Cool fill light from the opposite side; warm rim light from behind.
    const fillLight = new THREE.DirectionalLight(0xaaccff, 0.5);
    fillLight.position.set(-3, 3, -2);
    scene.add(fillLight);

    const rimLight = new THREE.DirectionalLight(0xffeedd, 0.4);
    rimLight.position.set(0, 4, -5);
    scene.add(rimLight);

    renderer.toneMapping = THREE.ACESFilmicToneMapping;
    renderer.toneMappingExposure = 1.0;
    renderer.outputColorSpace = THREE.SRGBColorSpace;

    renderer.setPixelRatio(window.devicePixelRatio);
    renderer.setSize(width, height);
    var container = document.getElementById('vis3d');
    container.appendChild(renderer.domElement);

    window.addEventListener('resize', onWindowResize);

    controls = new OrbitControls(camera, renderer.domElement);
    controls.minDistance = 1;
    controls.maxDistance = 15;
    controls.enableDamping = true;
    controls.dampingFactor = 0.05;
    controls.target.set(0, 1, 0);
    fitCameraToScene(scene, camera, controls, { axis_up: 'y', excludeNames: ['ground'] });

    // Distinguish a quick click (toggle play/pause) from an orbit drag.
    let isDragging = false;
    let mouseDownTime = 0;

    renderer.domElement.addEventListener('mousedown', () => {
        isDragging = false;
        mouseDownTime = Date.now();
    });

    renderer.domElement.addEventListener('mousemove', () => {
        // Movement more than 150 ms after mousedown counts as a drag.
        if (Date.now() - mouseDownTime > 150) {
            isDragging = true;
        }
    });

    renderer.domElement.addEventListener('mouseup', (e) => {
        // A short, motionless press (< 300 ms) toggles playback.
        if (!isDragging && Date.now() - mouseDownTime < 300) {
            if (modelsLoaded) {
                isPlaying ? pauseAnimation() : playAnimation();
            }
        }
    });

    renderer.domElement.addEventListener('dblclick', () => {
        // Double click: pause and rewind to the first frame.
        if (modelsLoaded) {
            pauseAnimation();
            currentFrame = 0;
            updateFrame();
        }
    });
}
964
+
965
// Render loop: reschedule itself, apply damped camera motion, then draw.
function animate() {
    requestAnimationFrame(animate);
    // OrbitControls needs per-frame updates only when damping is active.
    if (controls && controls.enableDamping) controls.update();
    renderer.render(scene, camera);
}
972
+
973
// Keep the camera aspect ratio and canvas size in sync with the viewport.
function onWindowResize() {
    const w = window.innerWidth;
    const h = window.innerHeight;
    camera.aspect = w / h;
    camera.updateProjectionMatrix();
    renderer.setSize(w, h);
}
980
+
981
+ function computeOffsets(batchSize) {
982
+ const spacing = 2.0;
983
+ const total_width = (batchSize - 1) * spacing;
984
+ const start_x = -total_width / 2;
985
+ const offsets = [];
986
+ for (let i = 0; i < batchSize; i++) {
987
+ offsets.push(start_x + i * spacing);
988
+ }
989
+ return offsets;
990
+ }
991
+
992
+ </script>
993
+
994
+ <style>
995
/* Fullscreen dark mode base styles */
* {
    margin: 0;
    padding: 0;
    box-sizing: border-box;
}

html, body {
    width: 100%;
    height: 100%;
    overflow: hidden;
    background: #424242 !important; /* matches the three.js scene background color */
    color: #e2e8f0;
}

/* Fullscreen container for 3D scene */
.fullscreen-container {
    position: fixed;
    top: 0;
    left: 0;
    width: 100vw;
    height: 100vh;
    background: #424242;
    overflow: hidden;
}

/* #vis3d is the mount point the renderer canvas is appended to in init() */
#vis3d {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    background: #424242;
}

#vis3d canvas {
    display: block;
    width: 100% !important;
    height: 100% !important;
}

/* Floating caption overlay */
.caption-overlay {
    position: absolute;
    top: 20px;
    left: 50%;
    transform: translateX(-50%);
    width: auto;
    max-width: 90%;
    z-index: 100;
    pointer-events: auto;
}

.motion-info {
    background-color: rgba(45, 55, 72, 0.85);
    backdrop-filter: blur(10px);
    -webkit-backdrop-filter: blur(10px); /* Safari prefix */
    border-radius: 20px;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.4);
    overflow: hidden;
    max-height: 40vh;
    overflow-y: auto;
    display: inline-block;
}

/* Floating progress control panel */
.control-overlay {
    position: absolute;
    bottom: 30px;
    left: 50%;
    transform: translateX(-50%);
    width: 80%;
    max-width: 600px;
    z-index: 100;
    background: rgba(0, 0, 0, 0.4);
    backdrop-filter: blur(8px);
    -webkit-backdrop-filter: blur(8px);
    padding: 15px 20px;
    border-radius: 12px;
}

.control-row-minimal {
    display: flex;
    align-items: center;
    gap: 20px;
}

.progress-container {
    flex: 1;
}

/* Progress slider: custom track/thumb styled separately for WebKit and Firefox */
.progress-slider-minimal {
    width: 100%;
    height: 8px;
    border-radius: 4px;
    background: rgba(255, 255, 255, 0.3);
    outline: none;
    cursor: pointer;
    -webkit-appearance: none;
    appearance: none;
}

.progress-slider-minimal::-webkit-slider-runnable-track {
    width: 100%;
    height: 8px;
    border-radius: 4px;
    background: rgba(255, 255, 255, 0.3);
}

.progress-slider-minimal::-webkit-slider-thumb {
    -webkit-appearance: none;
    appearance: none;
    width: 20px;
    height: 20px;
    border-radius: 50%;
    background: #4a9eff;
    cursor: pointer;
    border: 2px solid white;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.4);
    margin-top: -6px; /* centers the 20px thumb on the 8px track */
}

.progress-slider-minimal::-moz-range-track {
    width: 100%;
    height: 8px;
    border-radius: 4px;
    background: rgba(255, 255, 255, 0.3);
}

.progress-slider-minimal::-moz-range-thumb {
    width: 20px;
    height: 20px;
    border-radius: 50%;
    background: #4a9eff;
    cursor: pointer;
    border: 2px solid white;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.4);
}

.frame-counter {
    font-family: 'SF Mono', 'Consolas', monospace; /* monospace keeps digits from jittering */
    font-size: 14px;
    font-weight: 500;
    color: white;
    text-shadow: 0 1px 3px rgba(0, 0, 0, 0.5);
    white-space: nowrap;
    min-width: 80px;
    text-align: right;
}

/* Loading overlay */
.loading-overlay {
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
    background: rgba(0, 0, 0, 0.7);
    backdrop-filter: blur(8px);
    -webkit-backdrop-filter: blur(8px);
    color: white;
    padding: 15px 25px;
    border-radius: 10px;
    font-size: 14px;
    z-index: 200;
    display: flex;
    align-items: center;
    gap: 10px;
}

.loading-overlay.hidden {
    display: none;
}

.loading-overlay.complete {
    background: rgba(76, 175, 80, 0.85); /* green tint once every model has loaded */
}

/* Caption content styles */
.loading {
    padding: 10px 18px;
    text-align: center;
    color: #a0aec0;
    font-style: italic;
    white-space: nowrap;
}

.captions-section {
    padding: 12px 20px;
    white-space: nowrap;
}

.caption-item {
    background: transparent;
    border: none;
    border-radius: 0;
    margin-bottom: 6px;
    padding: 0;
    color: #f0f4f8;
    font-size: 1em;
    font-weight: 500;
    line-height: 1.5;
    text-align: center;
}

.caption-item:last-child {
    margin-bottom: 0;
}
1202
+ </style>
1203
+
1204
+ </body>
1205
+ </html>