chingshuai commited on
Commit
bcc06ab
·
1 Parent(s): f319393

copy init code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. .gitignore +25 -0
  3. configs/base/config.yml +37 -0
  4. examples/example_prompts/example_subset.json +61 -0
  5. gradio_app.py +763 -27
  6. hymotion/network/attention.py +110 -0
  7. hymotion/network/bricks.py +46 -0
  8. hymotion/network/encoders.py +121 -0
  9. hymotion/network/hymotion_mmdit.py +636 -0
  10. hymotion/network/modulate_layers.py +49 -0
  11. hymotion/network/positional_encoding.py +174 -0
  12. hymotion/network/text_encoders/model_constants.py +8 -0
  13. hymotion/network/text_encoders/text_encoder.py +293 -0
  14. hymotion/network/token_refiner.py +192 -0
  15. hymotion/pipeline/body_model.py +412 -0
  16. hymotion/pipeline/motion_diffusion.py +673 -0
  17. hymotion/prompt_engineering/model_constants.py +42 -0
  18. hymotion/prompt_engineering/prompt_rewrite.py +304 -0
  19. hymotion/utils/configs.py +344 -0
  20. hymotion/utils/geometry.py +856 -0
  21. hymotion/utils/loaders.py +184 -0
  22. hymotion/utils/misc.py +136 -0
  23. hymotion/utils/motion_process.py +154 -0
  24. hymotion/utils/path.py +168 -0
  25. hymotion/utils/smplh2fbx.py +585 -0
  26. hymotion/utils/smplh2woodfbx.py +702 -0
  27. hymotion/utils/t2m_runtime.py +378 -0
  28. hymotion/utils/type_converter.py +22 -0
  29. hymotion/utils/visualize_mesh_web.py +342 -0
  30. scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp +3 -0
  31. scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp +3 -0
  32. scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp +3 -0
  33. scripts/gradio/static/assets/dump_wooden/faces.bin +3 -0
  34. scripts/gradio/static/assets/dump_wooden/j_template.bin +3 -0
  35. scripts/gradio/static/assets/dump_wooden/joint_names.json +54 -0
  36. scripts/gradio/static/assets/dump_wooden/joints.ply +0 -0
  37. scripts/gradio/static/assets/dump_wooden/keypoints.bin +3 -0
  38. scripts/gradio/static/assets/dump_wooden/kintree.bin +3 -0
  39. scripts/gradio/static/assets/dump_wooden/skinIndice.bin +3 -0
  40. scripts/gradio/static/assets/dump_wooden/skinWeights.bin +3 -0
  41. scripts/gradio/static/assets/dump_wooden/uvs.bin +3 -0
  42. scripts/gradio/static/assets/dump_wooden/v_template.bin +3 -0
  43. scripts/gradio/static/scripts3d/create_ground.js +191 -0
  44. scripts/gradio/static/scripts3d/create_scene.js +195 -0
  45. scripts/gradio/static/scripts3d/draw_skeleton.js +121 -0
  46. scripts/gradio/static/scripts3d/load_smpl.js +126 -0
  47. scripts/gradio/static/scripts3d/load_wooden.js +167 -0
  48. scripts/gradio/templates/element/blank.html +53 -0
  49. scripts/gradio/templates/error_file_not_found.html +64 -0
  50. scripts/gradio/templates/index_smpl_gradio.html +938 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 3rdparty/fbxsdkpy-2020.1.post2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
37
+ assets/wooden_models/*.fbx filter=lfs diff=lfs merge=lfs -text
38
+ assets/wooden_models/boy_Rigging_smplx_tex.fbm/Boy_lambert4_BaseColor.png filter=lfs diff=lfs merge=lfs -text
39
+ *.webp filter=lfs diff=lfs merge=lfs -text
40
+ *.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cache_config.json
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ *.pyw
7
+ *.pyz
8
+ *.pywz
9
+ *.pyzw
10
+ *.pyzwz
11
+ cache
12
+
13
+ .vscode/
14
+
15
+ ckpts/*
16
+ !ckpts/README.md
17
+ assets/body_models/*
18
+ !assets/body_models/README.md
19
+ scripts/gradio/static/assets/dump_smplh
20
+ scripts/gradio/static/assets/export_wooden_to_js.py
21
+
22
+ test_*/
23
+ debug
24
+ tencent
25
+ output
configs/base/config.yml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ network_module: hymotion/network/hymotion_mmdit.HunyuanMotionMMDiT
2
+ network_module_args:
3
+ apply_rope_to_single_branch: false
4
+ ctxt_input_dim: 4096
5
+ dropout: 0.0
6
+ feat_dim: 1024
7
+ input_dim: 201
8
+ mask_mode: narrowband
9
+ mlp_ratio: 4.0
10
+ num_heads: 16
11
+ num_layers: 18
12
+ time_factor: 1000.0
13
+ vtxt_input_dim: 768
14
+ train_pipeline: hymotion/pipeline/motion_diffusion.MotionFlowMatching
15
+ train_pipeline_args:
16
+ enable_ctxt_null_feat: true
17
+ enable_special_game_feat: true
18
+ infer_noise_scheduler_cfg:
19
+ validation_steps: 50
20
+ losses_cfg:
21
+ recons:
22
+ name: SmoothL1Loss
23
+ weight: 1.0
24
+ noise_scheduler_cfg:
25
+ method: euler
26
+ output_mesh_fps: 30
27
+ random_generator_on_gpu: true
28
+ test_cfg:
29
+ mean_std_dir: ./stats/
30
+ text_guidance_scale: 5.0
31
+ text_encoder_cfg:
32
+ llm_type: qwen3
33
+ max_length_llm: 128
34
+ text_encoder_module: hymotion/network/text_encoders/text_encoder.HYTextModel
35
+ train_cfg:
36
+ cond_mask_prob: 0.1
37
+ train_frames: 360
examples/example_prompts/example_subset.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "test_prompts_subset": [
3
+ "A person jumps upward with both legs twice.#90#none#001",
4
+ "A person jumps on their right leg.#90#none#002",
5
+ "A person climbs upward, moving up the slope.#60#none#003",
6
+ "A person climbs an obstacle.#60#none#004",
7
+ "A person walks forward.#120#none#005",
8
+ "A person walks forward, moving arms and legs while looking left and right.#180#none#006",
9
+ "A person walks unsteadily, then slowly sits down.#150#none#007",
10
+ "A person turns backward 180 degrees, then walks forward.#120#none#008",
11
+ "A person walks in a catwalk style, swinging their left arm while placing their right hand on their hip.#180#none#009",
12
+ "A person squats down on tiptoe#120#none#010",
13
+ "A person sits down on a chair.#90#none#011",
14
+ "A person runs forward.#60#none#012",
15
+ "A person jumps up.#90#none#013",
16
+ "A person jumps forward lightly, taking two steps.#69#none#014",
17
+ "A person shoots a basketball.#60#none#015",
18
+ "A person finishes freestyle swimming, then surfaces.#120#none#016",
19
+ "A person swings a golf club, hitting the ball forward.#111#none#017",
20
+ "A person runs forward, then kicks a soccer ball.#60#none#018",
21
+ "A person walks on a tightrope.#180#none#019",
22
+ "A person performs a yoga camel pose, extending their back and lifting their chest.#210#none#020",
23
+ "A person performs a sit-up, holding their head with both hands.#150#none#021",
24
+ "A person performs a lunge stretch, hands on hips.#150#none#022",
25
+ "A person performs a deadlift, lifting a barbell from the ground.#150#none#023",
26
+ "A person marches in place, swinging their arms forward and backward.#210#none#024",
27
+ "A person perform a squat, not standing up#93#none#025",
28
+ "A person performs a squat#93#none#026",
29
+ "A person performs a front arm raise, then does a squat.#93#none#027",
30
+ "A person performs a squat, raising both arms forward.#240#none#028",
31
+ "A person does a squat, balling both hands into fists, lowering into a squat, then standing up.#195#none#029",
32
+ "A person plays the piano.#270#none#030",
33
+ "A person dances bachata, executing rhythmic hip movements and footwork.#240#none#031",
34
+ "A person plays the drums while sitting down, with wide, crossing arm movements.#90#none#032",
35
+ "A person plays the drums while sitting down, with arms spreading wide and then crossing over.#90#none#033",
36
+ "A person dances jazz, jumping rhythmically.#240#none#034",
37
+ "A person practices tai chi, performing slow, controlled movements.#270#none#035",
38
+ "A person waves their right hand, sitting on a beach chair.#71#none#036",
39
+ "A person was sweeping the floor with their head down.#180#none#037",
40
+ "A person picks up an object from ground#117#none#038",
41
+ "A person picks up an object from lower ground with two hands#99#none#039",
42
+ "A person picks up an object from lower ground with two hands, and lifts over head#126#none#040",
43
+ "A person speaks, gesturing with both hands.#75#none#041",
44
+ "A person lies on a bed, reading a book.#180#none#042",
45
+ "A person bends down to pick up an object, then stands up straight.#150#none#043",
46
+ "A person flips the wok#61#none#044",
47
+ "A person rolls over while lying down.#60#none#045",
48
+ "A person walks forward, holding a tray at shoulder height with one hand.#93#none#046",
49
+ "A person stands up from the chair, then stretches the arms.#300#none#047",
50
+ "A person turns to evade.#61#none#048",
51
+ "A person collapses to the ground after being hit.#60#none#049",
52
+ "A person swings a sword forward.#60#none#050",
53
+ "A person attacks, holding a shield in the right hand and a sword in the left.#45#none#051",
54
+ "A person walks like a zombie, dragging their feet forward.#120#none#052",
55
+ "A person performs a taekwondo kick, extending their leg forcefully.#60#none#053",
56
+ "A person blocks with a shield.#60#none#054",
57
+ "A person lifts a long gun, then walks forward slowly.#90#none#055",
58
+ "A person stumbles, being hit.#45#none#056",
59
+ "A person assumes a boxing stance, then shifts weight to the right and punches with the right hand.#60#none#057"
60
+ ]
61
+ }
gradio_app.py CHANGED
@@ -1,28 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- import sys
3
-
4
- try:
5
- import torch
6
- torch_version = torch.__version__
7
- except ImportError:
8
- print("torch not found, please install it")
9
- torch_version = "not found"
10
-
11
- try:
12
- import fbx
13
- try:
14
- fbx_version = fbx.__version__
15
- except AttributeError:
16
- # fbx module doesn't have __version__ attribute
17
- fbx_version = "installed (version unknown)"
18
- except ImportError:
19
- print("fbx not found, please install it")
20
- fbx_version = "not found"
21
-
22
- def greet(name):
23
- python_version = sys.version
24
- version = torch_version + " fbx version: " + fbx_version
25
- return "Hello " + name + "!! torch version: " + version + " python version: " + python_version
26
-
27
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
28
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # we should use gradio==5.38.2
2
+ import argparse
3
+ import codecs as cs
4
+ import json
5
+ import os
6
+ import os.path as osp
7
+ import random
8
+ import re
9
+ import textwrap
10
+ from typing import List, Optional, Tuple, Union
11
+
12
  import gradio as gr
13
+ import torch
14
+
15
+ from hymotion.utils.t2m_runtime import T2MRuntime
16
+
17
+ NUM_WORKERS = torch.cuda.device_count() if torch.cuda.is_available() else 1
18
+
19
+
20
+ # define data sources
21
+ DATA_SOURCES = {
22
+ "example_prompts": "examples/example_prompts/example_subset.json",
23
+ }
24
+
25
+ # create interface
26
+ APP_CSS = """
27
+ :root{
28
+ --primary-start:#667eea; --primary-end:#764ba2;
29
+ --secondary-start:#4facfe; --secondary-end:#00f2fe;
30
+ --accent-start:#f093fb; --accent-end:#f5576c;
31
+ --page-bg:linear-gradient(135deg,#f5f7fa 0%,#c3cfe2 100%);
32
+ --card-bg:linear-gradient(135deg,#ffffff 0%,#f8f9fa 100%);
33
+ --radius:12px;
34
+ --iframe-bg:#ffffff;
35
+ }
36
+
37
+ /* Dark mode variables */
38
+ [data-theme="dark"], .dark {
39
+ --page-bg:linear-gradient(135deg,#1a1a1a 0%,#2d3748 100%);
40
+ --card-bg:linear-gradient(135deg,#2d3748 0%,#374151 100%);
41
+ --text-primary:#f7fafc;
42
+ --text-secondary:#e2e8f0;
43
+ --border-color:#4a5568;
44
+ --input-bg:#374151;
45
+ --input-border:#4a5568;
46
+ --iframe-bg:#1a1a2e;
47
+ }
48
+
49
+ /* Page and card */
50
+ .gradio-container{
51
+ background:var(--page-bg) !important;
52
+ min-height:100vh !important;
53
+ color:var(--text-primary, #333) !important;
54
+ }
55
+
56
+ .main-header{
57
+ background:transparent !important; border:none !important; box-shadow:none !important;
58
+ padding:0 !important; margin:10px 0 16px !important;
59
+ text-align:center !important;
60
+ }
61
+
62
+ .main-header h1, .main-header p, .main-header li {
63
+ color:var(--text-primary, #333) !important;
64
+ }
65
+
66
+ .left-panel,.right-panel{
67
+ background:var(--card-bg) !important;
68
+ border:1px solid var(--border-color, #e9ecef) !important;
69
+ border-radius:15px !important;
70
+ box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
71
+ padding:24px !important;
72
+ }
73
+
74
+ .gradio-accordion{
75
+ border:1px solid var(--border-color, #e1e5e9) !important;
76
+ border-radius:var(--radius) !important;
77
+ margin:12px 0 !important; background:transparent !important;
78
+ }
79
+
80
+ .gradio-accordion summary{
81
+ background:transparent !important;
82
+ padding:14px 18px !important;
83
+ font-weight:600 !important;
84
+ color:var(--text-primary, #495057) !important;
85
+ }
86
+
87
+ .gradio-group{
88
+ background:transparent !important; border:none !important;
89
+ border-radius:8px !important; padding:12px 0 !important; margin:8px 0 !important;
90
+ }
91
+
92
+ /* Input class style - dark mode adaptation */
93
+ .gradio-textbox input,.gradio-textbox textarea,.gradio-dropdown .wrap{
94
+ border-radius:8px !important;
95
+ border:2px solid var(--input-border, #e9ecef) !important;
96
+ background:var(--input-bg, #fff) !important;
97
+ color:var(--text-primary, #333) !important;
98
+ transition:.2s all !important;
99
+ }
100
+
101
+ .gradio-textbox input:focus,.gradio-textbox textarea:focus,.gradio-dropdown .wrap:focus-within{
102
+ border-color:var(--primary-start) !important;
103
+ box-shadow:0 0 0 3px rgba(102,126,234,.1) !important;
104
+ }
105
+
106
+ .gradio-slider input[type="range"]{
107
+ background:linear-gradient(to right,var(--primary-start),var(--primary-end)) !important;
108
+ border-radius:10px !important;
109
+ }
110
+
111
+ .gradio-checkbox input[type="checkbox"]{
112
+ border-radius:4px !important;
113
+ border:2px solid var(--input-border, #e9ecef) !important;
114
+ transition:.2s all !important;
115
+ }
116
+
117
+ .gradio-checkbox input[type="checkbox"]:checked{
118
+ background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important;
119
+ border-color:var(--primary-start) !important;
120
+ }
121
+
122
+ /* Label text color adaptation */
123
+ .gradio-textbox label, .gradio-dropdown label, .gradio-slider label,
124
+ .gradio-checkbox label, .gradio-html label {
125
+ color:var(--text-primary, #333) !important;
126
+ }
127
+
128
+ .gradio-textbox .info, .gradio-dropdown .info, .gradio-slider .info,
129
+ .gradio-checkbox .info {
130
+ color:var(--text-secondary, #666) !important;
131
+ }
132
+
133
+ /* Status information - dark mode adaptation */
134
+ .gradio-textbox[data-testid*="状态信息"] input{
135
+ background:var(--input-bg, linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%)) !important;
136
+ border:2px solid var(--input-border, #dee2e6) !important;
137
+ color:var(--text-primary, #495057) !important;
138
+ font-weight:500 !important;
139
+ }
140
+
141
+ /* Button base class and variant */
142
+ .generate-button,.rewrite-button,.dice-button{
143
+ border:none !important; color:#fff !important; font-weight:600 !important;
144
+ border-radius:8px !important; transition:.3s all !important;
145
+ box-shadow:0 4px 15px rgba(0,0,0,.12) !important;
146
+ }
147
+
148
+ .generate-button{ background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important; }
149
+ .rewrite-button{ background:linear-gradient(45deg,var(--secondary-start),var(--secondary-end)) !important; }
150
+ .dice-button{
151
+ background:linear-gradient(45deg,var(--accent-start),var(--accent-end)) !important;
152
+ height:40px !important;
153
+ }
154
+
155
+ .generate-button:hover,.rewrite-button:hover{ transform:translateY(-2px) !important; }
156
+ .dice-button:hover{
157
+ transform:scale(1.05) !important;
158
+ box-shadow:0 4px 12px rgba(240,147,251,.28) !important;
159
+ }
160
+
161
+ .dice-container{
162
+ display:flex !important;
163
+ align-items:flex-end !important;
164
+ justify-content:center !important;
165
+ }
166
+
167
+ /* Right panel clipping overflow, avoid double scrollbars */
168
+ .right-panel{
169
+ background:var(--card-bg) !important;
170
+ border:1px solid var(--border-color, #e9ecef) !important;
171
+ border-radius:15px !important;
172
+ box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
173
+ padding:24px !important; overflow:hidden !important;
174
+ }
175
+
176
+ /* Main content row - ensure equal heights */
177
+ .main-row {
178
+ display: flex !important;
179
+ align-items: stretch !important;
180
+ }
181
+
182
+ /* Flask area - match left panel height */
183
+ .flask-display{
184
+ padding:0 !important; margin:0 !important; border:none !important;
185
+ box-shadow:none !important; background:var(--iframe-bg) !important;
186
+ border-radius:10px !important; position:relative !important;
187
+ height:100% !important; min-height:750px !important;
188
+ display:flex !important; flex-direction:column !important;
189
+ }
190
+
191
+ .flask-display iframe{
192
+ width:100% !important; flex:1 !important; min-height:750px !important;
193
+ border:none !important; border-radius:10px !important; display:block !important;
194
+ background:var(--iframe-bg) !important;
195
+ }
196
+
197
+ /* Right panel should stretch to match left panel */
198
+ .right-panel{
199
+ background:var(--card-bg) !important;
200
+ border:1px solid var(--border-color, #e9ecef) !important;
201
+ border-radius:15px !important;
202
+ box-shadow:0 4px 20px rgba(0,0,0,.08) !important;
203
+ padding:24px !important; overflow:hidden !important;
204
+ display:flex !important; flex-direction:column !important;
205
+ }
206
+
207
+ /* Ensure dropdown menu is visible in dark mode */
208
+ [data-theme="dark"] .gradio-dropdown .wrap,
209
+ .dark .gradio-dropdown .wrap {
210
+ background:var(--input-bg) !important;
211
+ color:var(--text-primary) !important;
212
+ }
213
+
214
+ [data-theme="dark"] .gradio-dropdown .option,
215
+ .dark .gradio-dropdown .option {
216
+ background:var(--input-bg) !important;
217
+ color:var(--text-primary) !important;
218
+ }
219
+
220
+ [data-theme="dark"] .gradio-dropdown .option:hover,
221
+ .dark .gradio-dropdown .option:hover {
222
+ background:var(--border-color) !important;
223
+ }
224
+
225
+ .footer{
226
+ text-align:center !important;
227
+ margin-top:20px !important;
228
+ padding:10px !important;
229
+ color:var(--text-secondary, #666) !important;
230
+ }
231
+ """
232
+
233
+ HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground"
234
+
235
+ FOOTER_MD = "*This is a Beta version, any issues or feedback are welcome!*"
236
+
237
+
238
+ def load_examples_from_txt(txt_path: str):
239
+ """Load examples from txt file."""
240
+
241
+ def _parse_line(line: str) -> Optional[Tuple[str, float]]:
242
+ line = line.strip()
243
+ if line and not line.startswith("#"):
244
+ parts = line.split("#")
245
+ if len(parts) >= 2:
246
+ text = parts[0].strip()
247
+ duration = int(parts[1]) / 20.0
248
+ else:
249
+ text = line.strip()
250
+ duration = 5.0
251
+ return text, duration
252
+ return None
253
+
254
+ examples: List[Tuple[str, float]] = []
255
+ if os.path.exists(txt_path):
256
+ try:
257
+ if txt_path.endswith(".txt"):
258
+ with cs.open(txt_path, "r", encoding="utf-8") as f:
259
+ lines = f.readlines()
260
+ for line in lines:
261
+ result = _parse_line(line)
262
+ if result is None:
263
+ continue
264
+ text, duration = result
265
+ examples.append((text, duration))
266
+ elif txt_path.endswith(".json"):
267
+ with cs.open(txt_path, "r", encoding="utf-8") as f:
268
+ lines = json.load(f)
269
+ for key, value in lines.items():
270
+ if "_raw_chn" in key or "GENERATE_PROMPT_FORMAT" in key:
271
+ continue
272
+ for line in value:
273
+ result = _parse_line(line)
274
+ if result is None:
275
+ continue
276
+ text, duration = result
277
+ examples.append((text, duration))
278
+ print(f">>> Loaded {len(examples)} examples from {txt_path}")
279
+ except Exception as e:
280
+ print(f">>> Failed to load examples from {txt_path}: {e}")
281
+ else:
282
+ print(f">>> Examples file not found: {txt_path}")
283
+
284
+ return examples
285
+
286
+
287
+ class T2MGradioUI:
288
+ def __init__(self, runtime: T2MRuntime, args: argparse.Namespace):
289
+ self.runtime = runtime
290
+ self.args = args
291
+
292
+ # Check if rewrite is available:
293
+ # - prompt_engineering_host must be provided
294
+ # - disable_rewrite must not be set
295
+ print(f">>> args: {vars(args)}")
296
+ self.rewrite_available = (
297
+ args.prompt_engineering_host is not None
298
+ and args.prompt_engineering_host.strip() != ""
299
+ and not args.disable_rewrite
300
+ )
301
+
302
+ self.all_example_data = {}
303
+ self._init_example_data()
304
+
305
+ def _init_example_data(self):
306
+ for source_name, file_path in DATA_SOURCES.items():
307
+ examples = load_examples_from_txt(file_path)
308
+ if examples:
309
+ self.all_example_data[source_name] = examples
310
+ else:
311
+ # provide default examples as fallback
312
+ self.all_example_data[source_name] = [
313
+ ("Twist at the waist and punch across the body.", 3.0),
314
+ ("A person is running then takes big leap.", 3.0),
315
+ ("A person holds a railing and walks down a set of stairs.", 5.0),
316
+ (
317
+ "A man performs a fluid and rhythmic hip-hop style dance, incorporating body waves, arm gestures, and side steps.",
318
+ 5.0,
319
+ ),
320
+ ]
321
+ print(f">>> Loaded data sources: {list(self.all_example_data.keys())}")
322
+
323
+ def _get_header_text(self):
324
+ return HEADER_BASE_MD
325
+
326
+ def _generate_random_seeds(self):
327
+ seeds = [random.randint(0, 999) for _ in range(4)]
328
+ return ",".join(map(str, seeds))
329
+
330
+ def _prompt_engineering(
331
+ self, text: str, duration: float, enable_rewrite: bool = True, enable_duration_est: bool = True
332
+ ):
333
+ if not text.strip():
334
+ return "", gr.update(interactive=False), gr.update()
335
+
336
+ call_llm = enable_rewrite or enable_duration_est
337
+ if not call_llm:
338
+ print(f"\t>>> Using original duration and original text...")
339
+ predicted_duration = duration
340
+ rewritten_text = text
341
+ else:
342
+ print(f"\t>>> Using LLM to estimate duration/rewrite text...")
343
+ try:
344
+ predicted_duration, rewritten_text = self.runtime.rewrite_text_and_infer_time(text=text)
345
+ except Exception as e:
346
+ print(f"\t>>> Text rewriting/duration prediction failed: {e}")
347
+ return (
348
+ f"❌ Text rewriting/duration prediction failed: {str(e)}",
349
+ gr.update(interactive=False),
350
+ gr.update(),
351
+ )
352
+ if not enable_rewrite:
353
+ rewritten_text = text
354
+ if not enable_duration_est:
355
+ predicted_duration = duration
356
+
357
+ return rewritten_text, gr.update(interactive=True), gr.update(value=predicted_duration)
358
+
359
+ def _generate_motion(
360
+ self,
361
+ original_text: str,
362
+ rewritten_text: str,
363
+ seed_input: str,
364
+ duration: float,
365
+ cfg_scale: float,
366
+ ) -> Tuple[str, List[str]]:
367
+ # When rewrite is not available, use original_text directly
368
+ if not self.rewrite_available:
369
+ text_to_use = original_text.strip()
370
+ if not text_to_use:
371
+ return "Error: Input text is empty, please enter text first", []
372
+ else:
373
+ text_to_use = rewritten_text.strip()
374
+ if not text_to_use:
375
+ return "Error: Rewritten text is empty, please rewrite the text first", []
376
+
377
+ try:
378
+ fbx_ok = getattr(self.runtime, "fbx_available", False)
379
+ req_format = "fbx" if fbx_ok else "dict"
380
+ html, fbx_files, _ = self.runtime.generate_motion(
381
+ text=text_to_use,
382
+ seeds_csv=seed_input,
383
+ duration=duration,
384
+ cfg_scale=cfg_scale,
385
+ output_format=req_format,
386
+ original_text=original_text,
387
+ output_dir=self.args.output_dir,
388
+ )
389
+ iframe_html = f"""
390
+ <iframe
391
+ src="{html}"
392
+ style="width: 100%; height: 800px; border: none; display: block; background: var(--iframe-bg); border-radius: 10px;"
393
+ frameborder="0"
394
+ scrolling="auto"
395
+ allowfullscreen
396
+ ></iframe>
397
+ """
398
+ return iframe_html, fbx_files
399
+ except Exception as e:
400
+ print(f"\t>>> Motion generation failed: {e}")
401
+ return (
402
+ f"❌ Motion generation failed: {str(e)}\n\nPlease check the input parameters or try again later",
403
+ [],
404
+ )
405
+
406
+ def _get_example_choices(self):
407
+ """Get all example choices from all data sources"""
408
+ choices = ["Custom Input"]
409
+ for source_name in self.all_example_data:
410
+ example_data = self.all_example_data[source_name]
411
+ for text, _ in example_data:
412
+ display_text = f"{text[:50]}..." if len(text) > 50 else text
413
+ choices.append(display_text)
414
+ return choices
415
+
416
+ def _on_example_select(self, selected_example):
417
+ """When selecting an example, the callback function"""
418
+ if selected_example == "Custom Input":
419
+ return "", self._generate_random_seeds(), gr.update()
420
+ else:
421
+ # find the corresponding example from all data sources
422
+ for source_name in self.all_example_data:
423
+ example_data = self.all_example_data[source_name]
424
+ for text, duration in example_data:
425
+ display_text = f"{text[:50]}..." if len(text) > 50 else text
426
+ if display_text == selected_example:
427
+ return text, self._generate_random_seeds(), gr.update(value=duration)
428
+ return "", self._generate_random_seeds(), gr.update()
429
+
430
+ def build_ui(self):
431
+ with gr.Blocks(css=APP_CSS) as demo:
432
+ self.header_md = gr.Markdown(HEADER_BASE_MD, elem_classes=["main-header"])
433
+
434
+ with gr.Row():
435
+ # Left control panel
436
+ with gr.Column(scale=2, elem_classes=["left-panel"]):
437
+ # Input textbox
438
+ self.text_input = gr.Textbox(
439
+ label="📝 Input Text",
440
+ placeholder="Enter text to generate motion, support Chinese and English text input.",
441
+ )
442
+ # Rewritten textbox
443
+ self.rewritten_text = gr.Textbox(
444
+ label="✏️ Rewritten Text",
445
+ placeholder="Rewritten text will be displayed here, you can further edit",
446
+ interactive=True,
447
+ visible=False,
448
+ )
449
+ # Duration slider
450
+ self.duration_slider = gr.Slider(
451
+ minimum=0.5,
452
+ maximum=12,
453
+ value=5.0,
454
+ step=0.1,
455
+ label="⏱️ Action Duration (seconds)",
456
+ info="Feel free to adjust the action duration",
457
+ )
458
+
459
+ # Execute buttons
460
+ with gr.Row():
461
+ if self.rewrite_available:
462
+ self.rewrite_btn = gr.Button(
463
+ "🔄 Rewrite Text",
464
+ variant="secondary",
465
+ size="lg",
466
+ elem_classes=["rewrite-button"],
467
+ )
468
+ else:
469
+ # Create a hidden/disabled placeholder button
470
+ self.rewrite_btn = gr.Button(
471
+ "🔄 Rewrite Text (Unavailable)",
472
+ variant="secondary",
473
+ size="lg",
474
+ elem_classes=["rewrite-button"],
475
+ interactive=False,
476
+ visible=False,
477
+ )
478
+
479
+ self.generate_btn = gr.Button(
480
+ "🚀 Generate Motion",
481
+ variant="primary",
482
+ size="lg",
483
+ elem_classes=["generate-button"],
484
+ interactive=not self.rewrite_available, # Enable directly if rewrite not available
485
+ )
486
+
487
+ if not self.rewrite_available:
488
+ gr.Markdown(
489
+ "> ⚠️ **Prompt engineering is not available.** Text rewriting and duration estimation are disabled. Your input text and duration will be used directly."
490
+ )
491
+
492
+ # Advanced settings
493
+ with gr.Accordion("🔧 Advanced Settings", open=False):
494
+ self._build_advanced_settings()
495
+
496
+ # Example selection dropdown
497
+ self.example_dropdown = gr.Dropdown(
498
+ choices=self._get_example_choices(),
499
+ value="Custom Input",
500
+ label="📚 Test Examples",
501
+ info="Select a preset example or input your own text above",
502
+ interactive=True,
503
+ )
504
+
505
+ # Status message depends on whether rewrite is available
506
+ if self.rewrite_available:
507
+ status_msg = "Please click the [🔄 Rewrite Text] button to rewrite the text first"
508
+ else:
509
+ status_msg = "Enter your text and click [🚀 Generate Motion] directly."
510
+
511
+ self.status_output = gr.Textbox(
512
+ label="📊 Status Information",
513
+ value=status_msg,
514
+ )
515
+
516
+ # FBX Download section
517
+ with gr.Row(visible=False) as self.fbx_download_row:
518
+ if getattr(self.runtime, "fbx_available", False):
519
+ self.fbx_files = gr.File(
520
+ label="📦 Download FBX Files",
521
+ file_count="multiple",
522
+ interactive=False,
523
+ )
524
+ else:
525
+ self.fbx_files = gr.State([])
526
+
527
+ # Right display area
528
+ with gr.Column(scale=3):
529
+ self.output_display = gr.HTML(show_label=False, elem_classes=["flask-display"])
530
+
531
+ # Footer
532
+ gr.Markdown(FOOTER_MD, elem_classes=["footer"])
533
+
534
+ self._bind_events()
535
+ demo.load(fn=self._get_header_text, outputs=[self.header_md])
536
+ return demo
537
+
538
+ def _build_advanced_settings(self):
539
+ # Only show rewrite options if rewrite is available
540
+ if self.rewrite_available:
541
+ with gr.Group():
542
+ gr.Markdown("### 🔄 Text Rewriting Options")
543
+ with gr.Row():
544
+ self.enable_rewrite = gr.Checkbox(
545
+ label="Enable Text Rewriting",
546
+ value=True,
547
+ info="Automatically optimize text prompt to get better motion generation",
548
+ )
549
+
550
+ with gr.Group():
551
+ gr.Markdown("### ⏱️ Duration Settings")
552
+ self.enable_duration_est = gr.Checkbox(
553
+ label="Enable Duration Estimation",
554
+ value=True,
555
+ info="Automatically estimate the duration of the motion",
556
+ )
557
+ else:
558
+ # Create hidden placeholders with default values (disabled)
559
+ self.enable_rewrite = gr.Checkbox(
560
+ label="Enable Text Rewriting",
561
+ value=False,
562
+ visible=False,
563
+ )
564
+ self.enable_duration_est = gr.Checkbox(
565
+ label="Enable Duration Estimation",
566
+ value=False,
567
+ visible=False,
568
+ )
569
+ with gr.Group():
570
+ gr.Markdown("### ⚠️ Prompt Engineering Unavailable")
571
+ gr.Markdown(
572
+ "Text rewriting and duration estimation are not available. "
573
+ "Your input text and duration will be used directly."
574
+ )
575
+
576
+ with gr.Group():
577
+ gr.Markdown("### ⚙️ Generation Parameters")
578
+ with gr.Row():
579
+ with gr.Column(scale=3):
580
+ self.seed_input = gr.Textbox(
581
+ label="🎯 Random Seed List (comma separated)",
582
+ value="0,1,2,3",
583
+ placeholder="Enter comma separated seed list (e.g.: 0,1,2,3)",
584
+ info="Random seeds control the diversity of generated motions",
585
+ )
586
+ with gr.Column(scale=1, min_width=60, elem_classes=["dice-container"]):
587
+ self.dice_btn = gr.Button(
588
+ "🎲 Lucky Button",
589
+ variant="secondary",
590
+ size="sm",
591
+ elem_classes=["dice-button"],
592
+ )
593
+
594
+ self.cfg_slider = gr.Slider(
595
+ minimum=1,
596
+ maximum=10,
597
+ value=5.0,
598
+ step=0.1,
599
+ label="⚙️ CFG Strength",
600
+ info="Text fidelity: higher = more faithful to the prompt",
601
+ )
602
+
603
+ def _bind_events(self):
604
+ # Generate random seeds
605
+ self.dice_btn.click(self._generate_random_seeds, outputs=[self.seed_input])
606
+
607
+ # Bind example selection event
608
+ self.example_dropdown.change(
609
+ fn=self._on_example_select,
610
+ inputs=[self.example_dropdown],
611
+ outputs=[self.text_input, self.seed_input, self.duration_slider],
612
+ )
613
+
614
+ # Rewrite text logic (only bind when rewrite is available)
615
+ if self.rewrite_available:
616
+ self.rewrite_btn.click(fn=lambda: "Rewriting text, please wait...", outputs=[self.status_output]).then(
617
+ self._prompt_engineering,
618
+ inputs=[
619
+ self.text_input,
620
+ self.duration_slider,
621
+ self.enable_rewrite,
622
+ self.enable_duration_est,
623
+ ],
624
+ outputs=[self.rewritten_text, self.generate_btn, self.duration_slider],
625
+ ).then(
626
+ fn=lambda: (
627
+ gr.update(visible=True),
628
+ "Text rewriting completed! Please check and edit the rewritten text, then click [🚀 Generate Motion]",
629
+ ),
630
+ outputs=[self.rewritten_text, self.status_output],
631
+ )
632
+
633
+ # Generate motion logic
634
+ self.generate_btn.click(
635
+ fn=lambda: "Generating motion, please wait... (It takes some extra time to start the renderer for the first generation)",
636
+ outputs=[self.status_output],
637
+ ).then(
638
+ self._generate_motion,
639
+ inputs=[
640
+ self.text_input,
641
+ self.rewritten_text,
642
+ self.seed_input,
643
+ self.duration_slider,
644
+ self.cfg_slider,
645
+ ],
646
+ outputs=[self.output_display, self.fbx_files],
647
+ concurrency_limit=NUM_WORKERS,
648
+ ).then(
649
+ fn=lambda fbx_list: (
650
+ (
651
+ "🎉 Motion generation completed! You can view the motion visualization result on the right. FBX files are ready for download."
652
+ if fbx_list
653
+ else "🎉 Motion generation completed! You can view the motion visualization result on the right"
654
+ ),
655
+ gr.update(visible=bool(fbx_list)),
656
+ ),
657
+ inputs=[self.fbx_files],
658
+ outputs=[self.status_output, self.fbx_download_row],
659
+ )
660
+
661
+ # Reset logic - different behavior based on rewrite availability
662
+ if self.rewrite_available:
663
+ self.text_input.change(
664
+ fn=lambda: (
665
+ gr.update(visible=False),
666
+ gr.update(interactive=False),
667
+ "Please click the [🔄 Rewrite Text] button to rewrite the text first",
668
+ ),
669
+ outputs=[self.rewritten_text, self.generate_btn, self.status_output],
670
+ )
671
+ else:
672
+ # When rewrite is not available, enable generate button directly when text is entered
673
+ self.text_input.change(
674
+ fn=lambda text: (
675
+ gr.update(visible=False),
676
+ gr.update(interactive=bool(text.strip())),
677
+ (
678
+ "Ready to generate! Click [🚀 Generate Motion] to start."
679
+ if text.strip()
680
+ else "Enter your text and click [🚀 Generate Motion] directly."
681
+ ),
682
+ ),
683
+ inputs=[self.text_input],
684
+ outputs=[self.rewritten_text, self.generate_btn, self.status_output],
685
+ )
686
+ # Only bind rewritten_text change when rewrite is available
687
+ if self.rewrite_available:
688
+ self.rewritten_text.change(
689
+ fn=lambda text: (
690
+ gr.update(interactive=bool(text.strip())),
691
+ (
692
+ "Rewritten text has been modified, you can click [🚀 Generate Motion]"
693
+ if text.strip()
694
+ else "Rewritten text cannot be empty, please enter valid text"
695
+ ),
696
+ ),
697
+ inputs=[self.rewritten_text],
698
+ outputs=[self.generate_btn, self.status_output],
699
+ )
700
+
701
+
702
if __name__ == "__main__":
    # CLI parsing (argparse) was dropped when this script was ported to a hosted
    # runtime; configuration is hard-coded below. The prompt-engineering host can
    # still be overridden via the PROMPT_HOST environment variable.
    final_model_path = "./configs/base"

    class Args:
        """Minimal stand-in for the original argparse namespace."""

        model_path = final_model_path
        output_dir = "output/gradio"
        prompt_engineering_host = os.environ.get("PROMPT_HOST", None)
        disable_rewrite = False

    args = Args()

    # Check required files: config is mandatory, checkpoint is optional.
    cfg = osp.join(args.model_path, "config.yml")
    ckpt = osp.join(args.model_path, "latest.ckpt")
    if not osp.exists(cfg):
        raise FileNotFoundError(f">>> Configuration file not found: {cfg}")

    # Check checkpoint file - skip model loading (but keep the UI alive) if missing.
    skip_model_loading = False
    if not osp.exists(ckpt):
        print(f">>> [WARNING] Checkpoint file not found: {ckpt}")
        print(">>> [WARNING] Model loading will be skipped. Motion generation will not be available.")
        skip_model_loading = True

    # Initialize runtime
    print(">>> Initializing T2MRuntime...")
    runtime = T2MRuntime(
        config_path=cfg,
        ckpt_name=ckpt,
        device_ids=None,
        prompt_engineering_host=args.prompt_engineering_host,
        skip_model_loading=skip_model_loading,
    )

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    ui = T2MGradioUI(runtime=runtime, args=args)
    demo = ui.build_ui()

    demo.launch()
hymotion/network/attention.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
+ try:
9
+ import flash_attn
10
+ from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func
11
+ except ImportError:
12
+ flash_attn = None
13
+ flash_attn_varlen_func = None
14
+ _flash_attn_forward = None
15
+
16
+
17
+ MEMORY_LAYOUT = {
18
+ "flash": (lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), lambda x: x),
19
+ "torch": (lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2)),
20
+ "vanilla": (lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2)),
21
+ }
22
+
23
+
24
+ def attention(
25
+ q: Tensor,
26
+ k: Tensor,
27
+ v: Tensor,
28
+ mode: str = "flash",
29
+ drop_rate: float = 0.0,
30
+ attn_mask: Optional[Tensor] = None,
31
+ causal: bool = False,
32
+ cu_seqlens_q: Optional[Tensor] = None,
33
+ cu_seqlens_kv: Optional[Tensor] = None,
34
+ max_seqlen_q: Optional[int] = None,
35
+ max_seqlen_kv: Optional[int] = None,
36
+ batch_size: int = 1,
37
+ training: bool = True,
38
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
39
+ """Perform QKV self attention.
40
+
41
+ Args:
42
+ q (Tensor): Query tensor with shape [b, s, h, d], where h is the number of heads.
43
+ k (Tensor): Key tensor with shape [b, s1, h, d]
44
+ v (Tensor): Value tensor with shape [b, s1, h, d]
45
+ mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
46
+ drop_rate (float): Dropout rate in attention map. (default: 0)
47
+ attn_mask (Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, h, s, s1] (torch or vanilla).
48
+ (default: None)
49
+ causal (bool): Whether to use causal attention. (default: False)
50
+ cu_seqlens_q (Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
51
+ used to index into q.
52
+ cu_seqlens_kv (Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
53
+ used to index into kv.
54
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
55
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
56
+
57
+ Returns:
58
+ Tensor: Output tensor after self attention with shape [b, s, hd]
59
+ """
60
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
61
+ q = pre_attn_layout(q)
62
+ k = pre_attn_layout(k)
63
+ v = pre_attn_layout(v)
64
+
65
+ if mode == "torch":
66
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
67
+ attn_mask = attn_mask.to(q.dtype)
68
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
69
+ elif mode == "flash":
70
+ assert flash_attn_varlen_func is not None, "flash_attn is not installed or not supported"
71
+ x = flash_attn_varlen_func(
72
+ q,
73
+ k,
74
+ v,
75
+ cu_seqlens_q,
76
+ cu_seqlens_kv,
77
+ max_seqlen_q,
78
+ max_seqlen_kv,
79
+ )
80
+ # x with shape [(bxs), a, d]
81
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
82
+ elif mode == "vanilla":
83
+ scale_factor = 1.0 / math.sqrt(q.size(-1))
84
+ b, a, s_q, _ = q.shape
85
+ s_k = k.size(2)
86
+ attn_bias = torch.zeros(b, a, s_q, s_k, dtype=q.dtype, device=q.device)
87
+ if causal:
88
+ # Only applied to self attention
89
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
90
+ temp_mask = torch.ones(b, a, s_q, s_q, dtype=torch.bool, device=q.device).tril(diagonal=0)
91
+ attn_bias.masked_fill_(~temp_mask, float("-inf"))
92
+ attn_bias = attn_bias.to(q.dtype)
93
+ if attn_mask is not None:
94
+ if attn_mask.dtype == torch.bool:
95
+ attn_bias.masked_fill_(~attn_mask, float("-inf"))
96
+ else:
97
+ attn_bias = attn_bias + attn_mask
98
+
99
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
100
+ attn = attn + attn_bias
101
+ attn = attn.softmax(dim=-1)
102
+ attn = torch.dropout(attn, p=drop_rate, train=training)
103
+ x = attn @ v
104
+ else:
105
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
106
+
107
+ x = post_attn_layout(x)
108
+ b, s, h, d = x.shape
109
+ out = x.reshape(b, s, -1)
110
+ return out
hymotion/network/bricks.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor
6
+
7
+
8
def get_activation_layer(act_type: str) -> Callable[[], nn.Module]:
    """Return a zero-argument factory that builds the requested activation module.

    Raises:
        ValueError: if ``act_type`` is not one of the supported names.
    """
    if act_type == "gelu":
        return lambda: nn.GELU()
    if act_type == "gelu_tanh":
        return lambda: nn.GELU(approximate="tanh")
    if act_type == "relu":
        return nn.ReLU
    if act_type == "silu":
        return nn.SiLU
    raise ValueError(f"Unknown activation type: {act_type}")
19
+
20
+
21
def get_norm_layer(norm_type: Optional[str]):
    """Map a normalization identifier to its layer class.

    ``None`` or ``"none"`` yield ``nn.Identity`` (i.e. no normalization).

    Raises:
        ValueError: for an unrecognized ``norm_type``.
    """
    if norm_type is None or norm_type == "none":
        return nn.Identity
    if norm_type == "layer":
        return nn.LayerNorm
    if norm_type == "rms":
        return RMSNorm
    raise ValueError(f"Unknown norm type: {norm_type}")
30
+
31
+
32
class RMSNorm(nn.Module):
    """Root-mean-square normalization (no mean subtraction), with an optional
    learned per-feature scale when ``elementwise_affine`` is True."""

    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Only register a weight when affine scaling is requested; forward()
        # checks for its presence via hasattr.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: Tensor) -> Tensor:
        # Divide by the RMS over the last dimension (eps keeps rsqrt finite).
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in float32 for stability, then cast back to the input dtype.
        normed = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            return normed * self.weight
        return normed
hymotion/network/encoders.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+ from ..utils.misc import to_2tuple
9
+ from .bricks import get_activation_layer, get_norm_layer
10
+ from .modulate_layers import ModulateDiT, modulate
11
+
12
+
13
class MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> act -> dropout -> (norm) -> fc2 -> dropout.

    Args:
        in_dim: Input feature dimension.
        feat_dim: Hidden feature dimension (falls back to ``in_dim`` if falsy).
        out_dim: Output dimension; defaults to ``in_dim``.
        act_type: Activation name understood by ``get_activation_layer``.
        norm_type: Optional norm applied after the first dropout.
        bias: Bias flag for both linears; a scalar is broadcast to (fc1, fc2).
        drop: Dropout probability; a scalar is broadcast to (drop1, drop2).
        use_conv: Use 1x1 Conv2d layers instead of Linear layers.
    """

    def __init__(
        self,
        in_dim: int,
        feat_dim: int,
        out_dim: Optional[int] = None,
        act_type: str = "gelu",
        norm_type: Optional[str] = None,
        bias: bool = True,
        drop: float = 0.0,
        use_conv: bool = False,
    ) -> None:
        super().__init__()
        out_dim = out_dim or in_dim
        feat_dim = feat_dim or in_dim
        # to_2tuple normalizes a scalar to a pair, so fc1/fc2 (and drop1/drop2)
        # can be configured independently; after this, indexing is always safe.
        # (The original re-checked isinstance for bias/drop1 but not drop2 —
        # now consistent across all four uses.)
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_dim, feat_dim, bias=bias[0])
        self.act = get_activation_layer(act_type)()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = get_norm_layer(norm_type)(feat_dim) if norm_type else nn.Identity()
        self.fc2 = linear_layer(feat_dim, out_dim, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
47
+
48
+
49
class MLPEncoder(nn.Module):
    """Plain MLP encoder: a first projection to ``feat_dim`` followed by
    ``num_layers - 1`` activation+Linear pairs, applied as one Sequential."""

    def __init__(self, in_dim: int, feat_dim: int, num_layers: int, act_type: str = "silu") -> None:
        super().__init__()
        self.in_dim = in_dim
        self.feat_dim = feat_dim
        layers = [nn.Linear(in_features=in_dim, out_features=self.feat_dim)]
        for _ in range(num_layers - 1):
            layers.append(get_activation_layer(act_type)())
            layers.append(nn.Linear(self.feat_dim, self.feat_dim))
        self.linears = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        return self.linears(x)
63
+
64
+
65
class FinalLayer(nn.Module):
    """Output head: AdaLN-modulated LayerNorm followed by a linear projection.

    With ``zero_init=True`` the projection starts at zero so the head initially
    outputs zeros (a common stabilization for diffusion output layers).
    """

    def __init__(self, feat_dim: int, out_dim: int, act_type: str = "gelu", zero_init=False, **kwargs):
        super().__init__()
        self.norm_final = nn.LayerNorm(feat_dim, elementwise_affine=False, eps=1e-6)
        # Produces (shift, scale) from the conditioning signal — hence factor=2.
        self.adaLN_modulation = ModulateDiT(feat_dim, factor=2, act_type=act_type)
        self.linear = nn.Linear(feat_dim, out_dim, bias=True)
        if zero_init:
            nn.init.zeros_(self.linear.weight)
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: Tensor, adapter: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(adapter).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift=shift, scale=scale)
        return self.linear(modulated)
81
+
82
+
83
class TimestepEmbeddingEncoder(nn.Module):
    """Encode scalar diffusion timesteps: sinusoidal embedding -> 2-layer MLP.

    ``time_factor`` rescales the raw timesteps before embedding; the output
    gains a singleton sequence dimension ([B, D] -> [B, 1, D]).
    """

    def __init__(
        self,
        embedding_dim: int,
        feat_dim: int,
        act_type: str = "silu",
        time_factor: float = 1.0,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.feat_dim = feat_dim
        self.time_factor = time_factor
        self.blocks = nn.Sequential(
            nn.Linear(embedding_dim, self.feat_dim),
            get_activation_layer(act_type)(),
            nn.Linear(self.feat_dim, self.feat_dim),
        )

    def forward(self, t: Tensor) -> Tensor:
        emb = self.sinusodial_embedding(t, self.embedding_dim, time_factor=self.time_factor)
        return self.blocks(emb).unsqueeze(1)

    @staticmethod
    def sinusodial_embedding(
        timesteps: Tensor, embedding_dim: int, temperature: float = 10000.0, time_factor: float = 1.0
    ) -> Tensor:
        """Transformer-style sinusoidal embedding, ordered [cos | sin], with a
        trailing zero column when ``embedding_dim`` is odd."""
        assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
        scaled = timesteps * time_factor
        half = embedding_dim // 2
        # Geometric frequency ladder from 1 down to 1/temperature.
        freqs = torch.exp(
            -torch.log(torch.tensor(temperature)) * torch.arange(start=0, end=half, dtype=torch.float) / half
        ).to(device=scaled.device)
        angles = scaled[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if embedding_dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
hymotion/network/hymotion_mmdit.py ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Tensor
10
+
11
+ from ..utils.loaders import load_object
12
+ from ..utils.type_converter import get_module_device
13
+ from .attention import attention
14
+ from .bricks import get_activation_layer, get_norm_layer
15
+ from .encoders import MLP, MLPEncoder, TimestepEmbeddingEncoder
16
+ from .modulate_layers import ModulateDiT, apply_gate, modulate
17
+ from .positional_encoding import RotaryEmbedding
18
+
19
+
20
class MMBaseBlock(nn.Module):
    """Shared state for the MM-DiT stream blocks: head geometry, MLP width and
    the rotary positional embedding used by subclasses."""

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        mlp_ratio: float,
        dropout: float,
        positional_encoding_cfg: dict,
        apply_rope_to_single_branch: bool,
    ):
        super().__init__()
        assert feat_dim % num_heads == 0, f"feat_dim {feat_dim} must be divisible by num_heads {num_heads}"

        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.head_dim = feat_dim // num_heads
        self.mlp_hidden_dim = int(feat_dim * mlp_ratio)

        # Keep a private copy so later mutation of the caller's dict cannot
        # change this block's configuration.
        self._positional_encoding_cfg = positional_encoding_cfg.copy()
        self.rotary_emb = RotaryEmbedding(num_feats=self.head_dim, **self._positional_encoding_cfg)
        self.apply_rope_to_single_branch = apply_rope_to_single_branch
44
+
45
+
46
class MMDoubleStreamBlock(MMBaseBlock):
    """Dual-stream MM-DiT block.

    Motion and text tokens keep separate modulation, QKV projections, output
    projections and MLPs, but attend jointly over the concatenated sequence
    (motion tokens first, then text tokens).
    """

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        mlp_ratio: float,
        dropout: float,
        mlp_act_type: str,
        qk_norm_type: Optional[str] = None,
        qkv_bias: bool = False,
        # NOTE(review): mutable default dict; safe only because MMBaseBlock
        # copies it — consider `None` + in-body default.
        positional_encoding_cfg: dict = {
            "max_seq_len": 5000,
            "use_real": True,
        },
        apply_rope_to_single_branch: bool = True,
    ):
        super().__init__(feat_dim, num_heads, mlp_ratio, dropout, positional_encoding_cfg, apply_rope_to_single_branch)

        # --- motion stream: AdaLN modulation (6 params: shift/scale/gate for
        # attention and for the MLP), fused QKV, per-head QK norm, out-proj, MLP.
        self.motion_mod = ModulateDiT(
            self.feat_dim,
            factor=6,
            act_type="silu",
        )
        self.motion_norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)

        motion_qkv_out_dim = self.feat_dim * 3
        self.motion_qkv = nn.Linear(self.feat_dim, motion_qkv_out_dim, bias=qkv_bias)

        self.motion_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.motion_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.motion_out_proj = nn.Linear(self.feat_dim, self.feat_dim, bias=qkv_bias)
        self.motion_norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
        self.motion_mlp = MLP(
            self.feat_dim,
            self.mlp_hidden_dim,
            act_type=mlp_act_type,
            bias=True,
        )

        # --- text stream: mirror of the motion stream with independent weights.
        self.text_mod = ModulateDiT(
            self.feat_dim,
            factor=6,
            act_type="silu",
        )
        self.text_norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)

        text_qkv_out_dim = self.feat_dim * 3
        self.text_qkv = nn.Linear(self.feat_dim, text_qkv_out_dim, bias=qkv_bias)

        self.text_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.text_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.text_out_proj = nn.Linear(self.feat_dim, self.feat_dim, bias=qkv_bias)
        self.text_norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)
        self.text_mlp = MLP(
            self.feat_dim,
            self.mlp_hidden_dim,
            act_type=mlp_act_type,
            bias=True,
        )

    def forward(
        self,
        motion_feat: Tensor,
        text_feat: Tensor,
        adapter: Tensor,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Run one joint-attention step.

        Args:
            motion_feat: Motion tokens, [B, L_motion, D].
            text_feat: Text tokens, [B, L_text, D].
            adapter: Conditioning signal driving the AdaLN modulation.
            attn_mask: Optional mask forwarded to the 'torch' SDPA backend.

        Returns:
            Updated (motion_feat, text_feat), same shapes as the inputs.
        """
        # Per-stream AdaLN parameters: shift/scale/gate for MSA and for the MLP.
        (
            motion_shift_msa,
            motion_scale_msa,
            motion_gate_msa,
            motion_shift_mlp,
            motion_scale_mlp,
            motion_gate_mlp,
        ) = self.motion_mod(adapter).chunk(6, dim=-1)
        (
            text_shift_msa,
            text_scale_msa,
            text_gate_msa,
            text_shift_mlp,
            text_scale_mlp,
            text_gate_mlp,
        ) = self.text_mod(adapter).chunk(6, dim=-1)

        # Motion stream: norm -> modulate -> fused QKV -> split heads.
        motion_modulated = self.motion_norm1(motion_feat)
        motion_modulated = modulate(motion_modulated, shift=motion_shift_msa, scale=motion_scale_msa)
        motion_qkv = self.motion_qkv(motion_modulated)

        motion_q, motion_k, motion_v = rearrange(motion_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        motion_q = self.motion_q_norm(motion_q).to(motion_v)
        motion_k = self.motion_k_norm(motion_k).to(motion_v)

        if self.apply_rope_to_single_branch:
            # NOTE: we don't apply RoPE to text_branch_two here
            motion_q, motion_k = self.rotary_emb.apply_rotary_emb(motion_q, motion_k)

        # Text stream: same pipeline with the text-specific weights.
        text_modulated = self.text_norm1(text_feat)
        text_modulated = modulate(text_modulated, shift=text_shift_msa, scale=text_scale_msa)
        text_qkv = self.text_qkv(text_modulated)

        text_q, text_k, text_v = rearrange(
            text_qkv,
            "B L (K H D) -> K B L H D",
            K=3,
            H=self.num_heads,
        )
        text_q = self.text_q_norm(text_q).to(text_v)
        text_k = self.text_k_norm(text_k).to(text_v)

        # Joint attention over the concatenated sequence (motion first).
        q = torch.cat((motion_q, text_q), dim=1)
        k = torch.cat((motion_k, text_k), dim=1)
        v = torch.cat((motion_v, text_v), dim=1)

        if not self.apply_rope_to_single_branch:
            # RoPE over the full concatenated sequence instead of motion only.
            q, k = self.rotary_emb.apply_rotary_emb(q, k)

        bsz, total_len, _, _ = q.shape
        motion_len = motion_feat.shape[1]
        text_len = text_feat.shape[1]  # currently unused (kept for clarity/debugging)
        dropout_p = 0.0 if not self.training else self.dropout

        attn_output = attention(
            q,
            k,
            v,
            mode="torch",
            drop_rate=dropout_p,
            attn_mask=attn_mask,
            causal=False,
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
            max_seqlen_q=None,
            max_seqlen_kv=None,
            batch_size=bsz,
            training=self.training,
        )

        # Split the joint output back into the two streams.
        motion_attn_output, text_attn_output = (
            attn_output[:, :motion_len, ...],
            attn_output[:, motion_len:, ...],
        )

        # Gated residual updates: attention then modulated MLP, per stream.
        motion_feat = motion_feat + apply_gate(self.motion_out_proj(motion_attn_output), gate=motion_gate_msa)
        motion_feat = motion_feat + apply_gate(
            self.motion_mlp(
                modulate(
                    self.motion_norm2(motion_feat),
                    shift=motion_shift_mlp,
                    scale=motion_scale_mlp,
                )
            ),
            gate=motion_gate_mlp,
        )

        text_feat = text_feat + apply_gate(self.text_out_proj(text_attn_output), gate=text_gate_msa)
        text_feat = text_feat + apply_gate(
            self.text_mlp(
                modulate(
                    self.text_norm2(text_feat),
                    shift=text_shift_mlp,
                    scale=text_scale_mlp,
                )
            ),
            gate=text_gate_mlp,
        )

        return motion_feat, text_feat
215
+
216
+
217
class MMSingleStreamBlock(MMBaseBlock):
    """Single-stream MM-DiT block.

    Operates on the already-concatenated (motion + text) sequence with one fused
    linear producing both QKV and the MLP hidden activation; ``split_len`` marks
    where the motion tokens end so RoPE can be applied to that prefix only.
    """

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        mlp_ratio: float,
        dropout: float,
        mlp_act_type: str,
        qk_norm_type: Optional[str] = None,
        qkv_bias: bool = False,
        # NOTE(review): mutable default dict; safe only because MMBaseBlock
        # copies it — consider `None` + in-body default.
        positional_encoding_cfg: dict = {
            "max_seq_len": 5000,
            "use_real": True,
        },
        apply_rope_to_single_branch: bool = True,
    ):
        super().__init__(feat_dim, num_heads, mlp_ratio, dropout, positional_encoding_cfg, apply_rope_to_single_branch)

        # AdaLN with 3 parameter groups: shift/scale/gate for the fused step.
        self.modulation = ModulateDiT(self.feat_dim, factor=3, act_type="silu")
        self.norm = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6)

        # qkv and mlp_in
        qkv_factor = 3
        self.linear1 = nn.Linear(self.feat_dim, self.feat_dim * qkv_factor + self.mlp_hidden_dim, bias=qkv_bias)
        # proj and mlp_out
        self.linear2 = nn.Linear(self.feat_dim + self.mlp_hidden_dim, self.feat_dim, bias=qkv_bias)

        self.q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)

        self.mlp_act = get_activation_layer(mlp_act_type)()

    def forward(
        self,
        x: Tensor,
        split_len: int,
        adapter: Tensor,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """One fused attention+MLP step over the concatenated sequence.

        Args:
            x: Concatenated tokens, [B, L_motion + L_text, D].
            split_len: Length of the motion prefix within ``x``.
            adapter: Conditioning signal driving the AdaLN modulation.
            attn_mask: Optional mask forwarded to the 'torch' SDPA backend.

        Returns:
            Updated tokens, same shape as ``x``.
        """
        (
            shift_msa,
            scale_msa,
            gate_msa,
        ) = self.modulation(adapter).chunk(3, dim=-1)
        x_modulated = modulate(self.norm(x), shift_msa, scale_msa)

        # One fused projection yields QKV and the MLP hidden activation.
        qkv, mlp_hidden = torch.split(self.linear1(x_modulated), [3 * self.feat_dim, self.mlp_hidden_dim], dim=-1)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)

        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Split into motion prefix (1) and text suffix (2) along the sequence.
        q1, q2 = q[:, :split_len, ...], q[:, split_len:, ...]
        k1, k2 = k[:, :split_len, ...], k[:, split_len:, ...]
        # apply rotary position embedding
        if self.apply_rope_to_single_branch:
            # RoPE on the motion prefix only; text tokens keep raw Q/K.
            q1, k1 = self.rotary_emb.apply_rotary_emb(q1, k1)
        q = torch.cat((q1, q2), dim=1)
        k = torch.cat((k1, k2), dim=1)
        if not self.apply_rope_to_single_branch:
            # RoPE over the entire concatenated sequence.
            q, k = self.rotary_emb.apply_rotary_emb(q, k)

        bsz, total_len = x_modulated.shape[:2]  # total_len currently unused
        dropout_p = 0.0 if not self.training else self.dropout

        attn_output = attention(
            q,
            k,
            v,
            mode="torch",
            drop_rate=dropout_p,
            attn_mask=attn_mask,
            causal=False,
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
            max_seqlen_q=None,
            max_seqlen_kv=None,
            batch_size=bsz,
            training=self.training,
        )
        # Concatenate attention output with the activated MLP hidden state and
        # project both back to feat_dim in a single linear.
        output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp_hidden)), 2))

        return x + apply_gate(output, gate=gate_msa)
302
+
303
+
304
+ class HunyuanMotionMMDiT(nn.Module):
305
+ def __init__(
306
+ self,
307
+ input_dim: int,
308
+ feat_dim: int,
309
+ output_dim: Optional[int] = None,
310
+ ctxt_input_dim: int = 4096,
311
+ vtxt_input_dim: int = 256,
312
+ text_refiner_module: str = "hymotion/network/token_refiner.SingleTokenRefiner",
313
+ text_refiner_cfg: dict = {
314
+ "num_layers": 2,
315
+ },
316
+ num_layers: int = 12,
317
+ num_heads: int = 16,
318
+ mlp_ratio: float = 4.0,
319
+ mlp_act_type: str = "gelu_tanh",
320
+ norm_type: str = "layer",
321
+ qk_norm_type: str = "rms",
322
+ qkv_bias: bool = True,
323
+ dropout: float = 0.0,
324
+ final_layer_module: str = "hymotion/network/encoders.FinalLayer",
325
+ final_layer_cfg: dict = {
326
+ "act_type": "silu",
327
+ },
328
+ mask_mode: Optional[str] = None,
329
+ apply_rope_to_single_branch: bool = True,
330
+ insert_start_token: bool = False,
331
+ with_long_skip_connection: bool = False,
332
+ time_factor: float = 1.0,
333
+ narrowband_length: float = 2.0,
334
+ **kwargs,
335
+ ):
336
+ super().__init__()
337
+ self.motion_input_dim = input_dim
338
+ self.ctxt_input_dim = ctxt_input_dim
339
+ self.vtxt_input_dim = vtxt_input_dim
340
+ self.feat_dim = feat_dim
341
+ self.output_dim = output_dim or input_dim
342
+ self.mask_mode = mask_mode
343
+ self.insert_start_token = insert_start_token
344
+ self.time_factor = time_factor
345
+ self.narrowband_length = narrowband_length * 30.0
346
+ if self.insert_start_token:
347
+ self.start_token = nn.Parameter(torch.randn(1, feat_dim))
348
+ self.with_long_skip_connection = with_long_skip_connection
349
+ if self.with_long_skip_connection:
350
+ from .encoders import FinalLayer
351
+
352
+ self.long_skip_net = FinalLayer(feat_dim=feat_dim, out_dim=feat_dim, act_type="silu")
353
+
354
+ self.input_encoder = nn.Linear(in_features=input_dim, out_features=feat_dim)
355
+ self.ctxt_encoder = nn.Linear(in_features=ctxt_input_dim, out_features=feat_dim)
356
+ self.vtxt_encoder = MLPEncoder(in_dim=vtxt_input_dim, feat_dim=feat_dim, num_layers=2, act_type="silu")
357
+ self.timestep_encoder = TimestepEmbeddingEncoder(
358
+ embedding_dim=feat_dim,
359
+ feat_dim=feat_dim,
360
+ time_factor=time_factor,
361
+ )
362
+
363
+ if text_refiner_module != "" and text_refiner_module is not None:
364
+ text_refiner_cfg.update(input_dim=feat_dim, feat_dim=feat_dim, num_heads=num_heads)
365
+ self._text_refiner_cfg = text_refiner_cfg.copy()
366
+ self.text_refiner = load_object(text_refiner_module, text_refiner_cfg)
367
+
368
+ self.num_layers = num_layers
369
+ assert num_layers % 3 == 0, f"num_layers must be divisible by 3, but got {num_layers}"
370
+ self.mm_double_blocks_layers = int(num_layers // 3)
371
+ self.mm_single_blocks_layers = int(num_layers - num_layers // 3)
372
+
373
+ self.double_blocks = nn.ModuleList(
374
+ [
375
+ MMDoubleStreamBlock(
376
+ feat_dim=feat_dim,
377
+ num_heads=num_heads,
378
+ mlp_ratio=mlp_ratio,
379
+ dropout=dropout,
380
+ mlp_act_type=mlp_act_type,
381
+ qk_norm_type=qk_norm_type,
382
+ qkv_bias=qkv_bias,
383
+ apply_rope_to_single_branch=apply_rope_to_single_branch,
384
+ )
385
+ for _ in range(self.mm_double_blocks_layers)
386
+ ]
387
+ )
388
+
389
+ self.single_blocks = nn.ModuleList(
390
+ [
391
+ MMSingleStreamBlock(
392
+ feat_dim=feat_dim,
393
+ num_heads=num_heads,
394
+ mlp_ratio=mlp_ratio,
395
+ dropout=dropout,
396
+ mlp_act_type=mlp_act_type,
397
+ qk_norm_type=qk_norm_type,
398
+ qkv_bias=qkv_bias,
399
+ apply_rope_to_single_branch=apply_rope_to_single_branch,
400
+ )
401
+ for _ in range(self.mm_single_blocks_layers)
402
+ ]
403
+ )
404
+
405
+ final_layer_cfg.update(feat_dim=feat_dim, out_dim=self.output_dim)
406
+ self._final_layer_cfg = final_layer_cfg.copy()
407
+ self.final_layer = load_object(final_layer_module, final_layer_cfg)
408
+
409
    def forward(
        self,
        x: Tensor,
        ctxt_input: Tensor,
        vtxt_input: Tensor,
        timesteps: Tensor,
        x_mask_temporal: Tensor,
        ctxt_mask_temporal: Tensor,
        **kwargs,
    ) -> Tensor:
        """Run one denoising step of the dual-stream MM-DiT.

        Args:
            x: Noisy motion features, shape (B, L, input_dim).
            ctxt_input: Per-token text context (LLM hidden states),
                shape (B, T, ctxt_input_dim).
            vtxt_input: Pooled sentence embedding, shape (B, 1, vtxt_input_dim).
            timesteps: Diffusion timesteps, shape (B,).
            x_mask_temporal: Boolean validity mask over motion frames, (B, L).
            ctxt_mask_temporal: Boolean validity mask over text tokens, (B, T).

        Returns:
            Tensor of shape (B, L, output_dim) from ``final_layer``.
        """
        device = get_module_device(self)

        motion_feat = self.input_encoder(x)
        if self.with_long_skip_connection:
            # Keep pre-transformer features for the long skip connection below.
            origin_feat = motion_feat
        if self.insert_start_token:
            # (B, 1, D) + (B, L, D) -> (B, L+1, D)
            start_token = self.start_token[None].repeat(motion_feat.shape[0], 1, 1)
            motion_feat = torch.cat((start_token, motion_feat), dim=1)
            # The start token is always valid, so prepend True to the mask.
            x_mask_temporal = torch.cat(
                [
                    torch.ones_like(x_mask_temporal[:, :1], dtype=torch.bool),
                    x_mask_temporal,
                ],
                dim=1,
            )

        # Adapter conditioning = timestep embedding + pooled text embedding.
        timestep_feat = self.timestep_encoder(timesteps)
        vtxt_feat = self.vtxt_encoder(vtxt_input.float())
        adapter = timestep_feat + vtxt_feat

        # Convert boolean masks to additive (0 / -inf) key-padding masks.
        motion_key_padding_mask = self._canonical_mask(x_mask_temporal).to(device)
        ctxt_key_padding_mask = self._canonical_mask(ctxt_mask_temporal).to(device)
        seq_key_padding_mask = torch.cat((motion_key_padding_mask, ctxt_key_padding_mask), dim=1)
        # Optional structural mask over the motion-motion quadrant.
        if self.mask_mode is None:
            seq_mask = None
        elif self.mask_mode == "causal":
            motion_len = motion_feat.shape[1]
            seq_mask = torch.triu(
                torch.full((motion_len, motion_len), float("-inf"), device=device),
                diagonal=1,
            )
        elif self.mask_mode == "narrowband":
            # Band mask: frame i attends to frames within +/- window.
            # (narrowband_length was scaled by 30.0 in __init__ — presumably
            # seconds -> frames at 30 fps; TODO confirm.)
            window = int(round(self.narrowband_length))
            motion_len = motion_feat.shape[1]
            idx = torch.arange(motion_len, device=device)
            dist = (idx[None, :] - idx[:, None]).abs()
            band = dist <= window
            seq_mask = torch.full((motion_len, motion_len), float("-inf"), device=device)
            seq_mask = seq_mask.masked_fill(band, 0.0)
        else:
            raise ValueError(f"Unsupported mask mode: {self.mask_mode}")

        ctxt_feat = self.ctxt_encoder(ctxt_input.float())
        if hasattr(self, "text_refiner"):
            # Refiner expects a boolean mask (True = valid token).
            ctxt_feat = self.text_refiner(x=ctxt_feat, t=timesteps, mask=(ctxt_key_padding_mask == 0).to(device))

        # precompute shared attention masks (broadcastable over heads)
        bsz = x.shape[0]
        motion_len = motion_feat.shape[1]
        text_len = ctxt_feat.shape[1]
        total_len = motion_len + text_len
        mask_dtype = motion_feat.dtype
        attn_mask_double = self._build_dmm_attn_mask_shared(
            bsz=bsz,
            motion_len=motion_len,
            text_len=text_len,
            dtype=mask_dtype,
            key_padding_mask=seq_key_padding_mask,
            attn_mask=seq_mask,
            device=device,
        )
        for i_layer, mod in enumerate(self.double_blocks):
            motion_feat, ctxt_feat = mod(
                motion_feat=motion_feat,
                text_feat=ctxt_feat,
                adapter=adapter,
                attn_mask=attn_mask_double,
            )

        # precompute shared attention masks for single stream blocks too
        split_len = motion_feat.shape[1]
        x = torch.cat((motion_feat, ctxt_feat), 1)
        attn_mask_single = self._build_smm_attn_mask_shared(
            bsz=bsz,
            split_len=split_len,
            total_len=total_len,
            dtype=mask_dtype,
            key_padding_mask=seq_key_padding_mask,
            attn_mask=seq_mask,
            device=device,
        )
        for i_layer, mod in enumerate(self.single_blocks):
            x = mod(
                x=x,
                split_len=split_len,
                adapter=adapter,
                attn_mask=attn_mask_single,
            )

        # Drop the text tokens, then the start token (if it was inserted).
        x = x[:, :split_len, ...]
        if self.insert_start_token:
            x = x[:, 1:, ...]

        if self.with_long_skip_connection:
            # long skip only consider timestep_feat
            x = self.long_skip_net(origin_feat, timestep_feat) + x

        predicted_res = self.final_layer(x, adapter)
        return predicted_res
519
+
520
+ @staticmethod
521
+ def _canonical_mask(input_mask: Tensor) -> Tensor:
522
+ if input_mask.ndim == 1:
523
+ input_mask = input_mask.unsqueeze(1)
524
+ key_padding_mask = torch.where(
525
+ input_mask,
526
+ torch.zeros_like(input_mask, dtype=torch.float),
527
+ torch.full_like(input_mask, float("-inf"), dtype=torch.float),
528
+ )
529
+ return key_padding_mask
530
+
531
    def _build_dmm_attn_mask_shared(
        self,
        bsz: int,
        motion_len: int,
        text_len: int,
        dtype: torch.dtype,
        key_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor],
        device: torch.device,
    ) -> Tensor:
        """Build the additive attention mask shared by all double-stream blocks.

        NOTE:
                      motion_k  text_k
            motion_q  [M→M]     [M→T]
            text_q    [T→M]     [T→T]
        only [M→M] contains given mask; [T→M] is fully disabled so text
        queries never attend to motion keys.

        Args:
            bsz: Batch size.
            motion_len: Number of motion tokens (query/key prefix length).
            text_len: Number of text tokens.
            dtype: dtype of the returned mask (should match the features).
            key_padding_mask: Additive (0/-inf) padding mask of shape
                (bsz, motion_len) or (bsz, motion_len + text_len).
            attn_mask: Optional 2D additive mask (motion_len, motion_len)
                applied only to the M→M quadrant.
            device: Target device.

        Returns:
            Additive mask of shape (bsz, 1, total_len, total_len),
            broadcastable over attention heads.
        """
        total_len = motion_len + text_len
        base = torch.zeros((bsz, 1, total_len, total_len), dtype=dtype, device=device)
        if attn_mask is not None:
            if attn_mask.dim() != 2 or attn_mask.shape != (motion_len, motion_len):
                raise RuntimeError(
                    f"attn_mask should be 2D with shape {(motion_len, motion_len)}, got {attn_mask.shape}"
                )
            base[:, :, :motion_len, :motion_len] += attn_mask.view(1, 1, motion_len, motion_len)
        if key_padding_mask is not None:
            mask_total_len = key_padding_mask.shape[1]
            if mask_total_len == motion_len:
                # Only the motion part was provided: text keys are unpadded (0).
                pad = torch.zeros((bsz, text_len), dtype=key_padding_mask.dtype, device=device)
                key_padding_mask = torch.cat((key_padding_mask, pad), dim=-1)
            base = base + key_padding_mask.view(bsz, 1, 1, total_len)
        # disable T→M
        base[:, :, motion_len:, :motion_len] = float("-inf")
        return base
565
+
566
    def _build_smm_attn_mask_shared(
        self,
        bsz: int,
        split_len: int,
        total_len: int,
        dtype: torch.dtype,
        key_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor],
        device: torch.device,
    ) -> Tensor:
        """Build the additive attention mask shared by all single-stream blocks.

        NOTE:
                      motion_k  text_k
            motion_q  [M→M]     [M→T]
            text_q    [T→M]     [T→T]
        only [M→M] contains given mask; [T→M] is fully disabled. Same layout
        as the double-stream variant, but the boundary is ``split_len`` in a
        pre-concatenated (motion + text) sequence.

        Args:
            bsz: Batch size.
            split_len: Index where the motion tokens end in the joint sequence.
            total_len: Full joint sequence length (motion + text).
            dtype: dtype of the returned mask.
            key_padding_mask: Additive (0/-inf) padding mask of shape
                (bsz, split_len) or (bsz, total_len).
            attn_mask: Optional 2D additive mask (split_len, split_len) for
                the M→M quadrant.
            device: Target device.

        Returns:
            Additive mask of shape (bsz, 1, total_len, total_len),
            broadcastable over attention heads.
        """
        base = torch.zeros((bsz, 1, total_len, total_len), dtype=dtype, device=device)
        if attn_mask is not None:
            if attn_mask.dim() != 2 or attn_mask.shape != (split_len, split_len):
                raise RuntimeError(f"attn_mask should be 2D with shape {(split_len, split_len)}, got {attn_mask.shape}")
            base[:, :, :split_len, :split_len] += attn_mask.view(1, 1, split_len, split_len)
        if key_padding_mask is not None:
            mask_total_len = key_padding_mask.shape[1]
            if mask_total_len == split_len:
                # Only the motion part was provided: text keys are unpadded (0).
                pad = torch.zeros(
                    (bsz, total_len - split_len),
                    dtype=key_padding_mask.dtype,
                    device=device,
                )
                key_padding_mask = torch.cat((key_padding_mask, pad), dim=-1)
            base = base + key_padding_mask.view(bsz, 1, 1, total_len)
        # disable T→M
        base[:, :, split_len:, :split_len] = float("-inf")
        return base
601
+
602
+
603
+ if __name__ == "__main__":
604
+ # python -m hymotion.network.hymotion_mmdit
605
+
606
+ from configs._base_.model_network_base import MOTION_MODEL_CONFIG # pyright: ignore
607
+
608
+ network_module_cfg = MOTION_MODEL_CONFIG["1.04B_narrowband"]["network_module_args"]
609
+ network_module_cfg = dict(network_module_cfg) # convert to normal dict
610
+
611
+ bsz, seq_len, text_seq_len, input_dim = 1, 360, 128, 201
612
+ network_module_cfg["input_dim"] = input_dim
613
+ MMDiT = HunyuanMotionMMDiT(**network_module_cfg)
614
+
615
+ x = torch.randn(bsz, seq_len, input_dim)
616
+ ctxt_condition = torch.randn(bsz, text_seq_len, 4096)
617
+ vtxt_condition = torch.randn(bsz, 1, 768)
618
+ timesteps = torch.randint(0, 1000, (bsz,))
619
+ length = torch.arange(seq_len).unsqueeze(0).repeat(bsz, 1)
620
+ ctxt_length = torch.arange(text_seq_len).unsqueeze(0).repeat(bsz, 1)
621
+ x_mask_temporal = length < 100
622
+ ctxt_mask_temporal = ctxt_length < 50
623
+ x = MMDiT(
624
+ x=x,
625
+ ctxt_input=ctxt_condition,
626
+ vtxt_input=vtxt_condition,
627
+ timesteps=timesteps,
628
+ x_mask_temporal=x_mask_temporal,
629
+ ctxt_mask_temporal=ctxt_mask_temporal,
630
+ )
631
+ assert x.shape == (
632
+ bsz,
633
+ seq_len,
634
+ input_dim,
635
+ ), f"unexpected output shape: {x.shape}, which should be ({bsz}, {seq_len}, {input_dim})"
636
+ print(x.shape)
hymotion/network/modulate_layers.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch.nn as nn
4
+ from torch import Tensor
5
+
6
+ from .bricks import get_activation_layer
7
+
8
+
9
class ModulateDiT(nn.Module):
    """Zero-initialized modulation projection for adaLN-style DiT blocks.

    Applies an activation followed by a linear layer that expands the
    feature dimension by ``factor``. Weight and bias start at zero, so the
    module initially emits all-zero modulation parameters.
    """

    def __init__(self, feat_dim: int, factor: int, act_type: str = "silu"):
        super().__init__()
        self.act = get_activation_layer(act_type)()
        self.linear = nn.Linear(feat_dim, factor * feat_dim, bias=True)
        # Zero init => no modulation at the start of training.
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)

    def forward(self, x: Tensor) -> Tensor:
        activated = self.act(x)
        return self.linear(activated)
19
+
20
+
21
def modulate(x: Tensor, shift: Optional[Tensor] = None, scale: Optional[Tensor] = None) -> Tensor:
    """Apply adaLN-style affine modulation ``x * (1 + scale) + shift``.

    Either operand may be ``None``, in which case that part of the transform
    is skipped; with both ``None`` the input is returned unchanged. When both
    are given, all three tensors must share the same rank.
    """
    has_shift = shift is not None
    has_scale = scale is not None
    if has_shift and has_scale:
        assert len(x.shape) == len(shift.shape) == len(scale.shape), (
            "x, shift, scale must have the same number of dimensions, "
            f"but got x.shape: {x.shape}, "
            f"shift.shape: {shift.shape} "
            f"and scale.shape: {scale.shape}"
        )
        return x * (1 + scale) + shift
    if has_shift:
        return x + shift
    if has_scale:
        return x * (1 + scale)
    return x
37
+
38
+
39
def apply_gate(x: Tensor, gate: Optional[Tensor] = None, tanh: bool = False) -> Tensor:
    """Elementwise-gate ``x``, optionally squashing the gate through tanh.

    With ``gate=None`` the input passes through unchanged. Otherwise the
    gate must have the same rank as ``x``.
    """
    if gate is None:
        return x
    assert len(x.shape) == len(
        gate.shape
    ), f"x, gate must have the same number of dimensions, but got {x.shape} and {gate.shape}"
    factor = gate.tanh() if tanh else gate
    return x * factor
hymotion/network/positional_encoding.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+
9
+ class RotaryEmbedding(nn.Module):
10
+ def __init__(
11
+ self,
12
+ num_feats: int,
13
+ max_seq_len: Union[Tensor, int],
14
+ temperature: int = 10000,
15
+ use_real: bool = False,
16
+ theta_rescale_factor: float = 1.0,
17
+ interpolation_factor: float = 1.0,
18
+ ) -> None:
19
+ super(RotaryEmbedding, self).__init__()
20
+ assert num_feats % 2 == 0, "num_feats (head_dim) must be even for RoPE."
21
+ self.num_feats = num_feats
22
+ self.max_seq_len = max_seq_len
23
+ self.temperature = temperature
24
+ self.use_real = use_real
25
+ self.theta_rescale_factor = theta_rescale_factor
26
+ self.interpolation_factor = interpolation_factor
27
+
28
+ if isinstance(max_seq_len, int):
29
+ max_seq_len = torch.arange(max_seq_len).float()
30
+
31
+ if theta_rescale_factor != 1.0:
32
+ temperature *= theta_rescale_factor ** (self.num_feats / (self.num_feats - 2))
33
+ dim_t = torch.arange(0, self.num_feats, 2, dtype=torch.float32)
34
+ freqs = 1.0 / (temperature ** (2 * torch.div(dim_t, 2, rounding_mode="trunc") / self.num_feats)) # [D/2]
35
+ freqs = torch.outer(max_seq_len.float() * interpolation_factor, freqs) # [S, D/2]
36
+ if use_real:
37
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
38
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
39
+ self.freqs_cis = (freqs_cos, freqs_sin)
40
+ else:
41
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2]
42
+ self.freqs_cis = freqs_cis
43
+
44
+ def reshape_for_broadcast(
45
+ self, freqs_cis: Union[Tensor, Tuple[Tensor, Tensor]], x: Tensor
46
+ ) -> Union[Tuple[Tensor, Tensor], Tensor]:
47
+ ndim = x.ndim
48
+ assert 0 <= 1 < ndim
49
+
50
+ if isinstance(freqs_cis, tuple):
51
+ # freqs_cis: (cos, sin) in real space
52
+ assert (
53
+ freqs_cis[0].shape[-1] == x.shape[-1]
54
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape} on the head_dim dimension"
55
+ assert freqs_cis[0].shape[0] >= x.shape[1], (
56
+ f"freqs_cis shape {freqs_cis[0].shape} should be larger than or equal to "
57
+ f"x shape {x.shape} on the time dimension"
58
+ )
59
+ shape = []
60
+ for i, d in enumerate(x.shape):
61
+ if i == 1:
62
+ shape.append(-1)
63
+ elif i == ndim - 1:
64
+ shape.append(d)
65
+ else:
66
+ shape.append(1)
67
+ return (
68
+ freqs_cis[0].view(*shape)[:, : x.shape[1], ...],
69
+ freqs_cis[1].view(*shape)[:, : x.shape[1], ...],
70
+ )
71
+ else:
72
+ # freqs_cis: values in complex space
73
+ assert (
74
+ freqs_cis.shape[-1] == x.shape[-1]
75
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape} on the head_dim dimension"
76
+ assert freqs_cis.shape[0] >= x.shape[1], (
77
+ f"freqs_cis shape {freqs_cis.shape} should be larger than or equal to "
78
+ f"x shape {x.shape} on the time dimension"
79
+ )
80
+ shape = []
81
+ for i, d in enumerate(x.shape):
82
+ if i == 1:
83
+ shape.append(-1)
84
+ elif i == ndim - 1:
85
+ shape.append(d)
86
+ else:
87
+ shape.append(1)
88
+ return freqs_cis.view(*shape)[:, : x.shape[1], ...]
89
+
90
+ def rotate_half(self, x: Tensor) -> Tensor:
91
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
92
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
93
+
94
+ def apply_rotary_emb(self, xq: Tensor, xk: Tensor) -> Tuple[Tensor, Tensor]:
95
+ xk_out = None
96
+ if isinstance(self.freqs_cis, tuple):
97
+ cos, sin = self.reshape_for_broadcast(self.freqs_cis, xq) # [B, L, H, D]
98
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
99
+ # real * cos - imag * sin
100
+ # imag * cos + real * sin
101
+ xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
102
+ xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
103
+ else:
104
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
105
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
106
+ freqs_cis = self.reshape_for_broadcast(self.freqs_cis, xq_)
107
+ # Handle device transfer based on return type
108
+ if isinstance(freqs_cis, tuple):
109
+ freqs_cis = (freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device))
110
+ else:
111
+ freqs_cis = freqs_cis.to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
112
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
113
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
114
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
115
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
116
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
117
+
118
+ return xq_out, xk_out
119
+
120
+ def __repr__(self) -> str:
121
+ repr_str = self.__class__.__name__
122
+ repr_str += f"(num_feats={self.num_feats}, "
123
+ repr_str += f"max_seq_len={self.max_seq_len}, "
124
+ repr_str += f"temperature={self.temperature}, "
125
+ repr_str += f"use_real={self.use_real}, "
126
+ repr_str += f"theta_rescale_factor={self.theta_rescale_factor}, "
127
+ repr_str += f"interpolation_factor={self.interpolation_factor})"
128
+ return repr_str
129
+
130
+
131
+ class PositionalEncoding(nn.Module):
132
+ def __init__(self, num_feats: int, dropout: float = 0.1, max_len: int = 5000):
133
+ super(PositionalEncoding, self).__init__()
134
+ self.dropout = nn.Dropout(p=dropout)
135
+
136
+ pe = torch.zeros(max_len, num_feats)
137
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
138
+ div_term = torch.exp(torch.arange(0, num_feats, 2).float() * (-np.log(10000.0) / num_feats))
139
+ pe[:, 0::2] = torch.sin(position * div_term)
140
+ pe[:, 1::2] = torch.cos(position * div_term)
141
+ pe = pe.unsqueeze(0) # shape of [1, L, D]
142
+ self.register_buffer("pe", pe)
143
+
144
+ def forward(self, x: Tensor) -> Tensor:
145
+ x = x + self.pe[:, : x.shape[1], :] # shape of [B, L, D]
146
+ return self.dropout(x)
147
+
148
+
149
+ if __name__ == "__main__":
150
+ # python -m hymotion.network.positional_encoding
151
+ num_feats = 32
152
+ rope = RotaryEmbedding(num_feats=num_feats, max_seq_len=5000, use_real=True)
153
+ x = torch.ones(1, 360, 1, num_feats)
154
+ text = torch.ones(1, 256, 1, num_feats)
155
+ q1, k1 = x.clone(), x.clone()
156
+ q2, k2 = text.clone(), text.clone()
157
+ print(x.shape)
158
+ # q1, k1 = rope.apply_rotary_emb(q1, k1)
159
+ # q2, k2 = rope.apply_rotary_emb(q2, k2)
160
+ q = torch.cat([q1, q2], dim=1)
161
+ k = torch.cat([k1, k2], dim=1)
162
+ q, k = rope.apply_rotary_emb(q, k)
163
+ q, k = q[0, :, 0, :], k[0, :, 0, :]
164
+ attn = (q[:, None] * k[None, :]).sum(dim=-1)
165
+ # softmax
166
+ # attn = torch.softmax(attn, dim=-1)
167
+ attn = attn.cpu().numpy()
168
+
169
+ import matplotlib.pyplot as plt
170
+
171
+ plt.imshow(attn, cmap="hot")
172
+ plt.colorbar()
173
+ plt.savefig("attn.png")
174
+ breakpoint()
hymotion/network/text_encoders/model_constants.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
__all__ = [
    "PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION",
]


# System prompt used when encoding user text with the LLM text encoder
# (imported by hymotion/network/text_encoders/text_encoder.py). It instructs
# the model to summarize only motion-relevant information from the prompt.
PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION = """
Summarize human motion only from the user text for representation: action categories, key body-part movements, order/transitions, trajectory/direction, posture; include style/emotion/speed only if present. Explicitly capture laterality (left/right) when mentioned; do not guess. If multiple actions are described, indicate the count of distinct actions (e.g., actions=3) and their order. Do not invent missing info. Keep one concise paragraph.
"""
hymotion/network/text_encoders/text_encoder.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+ from transformers import (
8
+ AutoModelForCausalLM,
9
+ AutoTokenizer,
10
+ CLIPTextModel,
11
+ CLIPTokenizer,
12
+ )
13
+
14
+ from ...utils.type_converter import get_module_device
15
+ from .model_constants import PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION
16
+
17
# Select between Hugging Face hub model IDs and local checkpoint paths via
# the USE_HF_MODELS environment variable ("1" -> hub IDs).
USE_HF_MODELS = os.environ.get("USE_HF_MODELS", "0") == "1"

if USE_HF_MODELS:
    QWEN_PATH = "Qwen/Qwen3-8B"
    CLIP_PATH = "openai/clip-vit-large-patch14"
else:
    QWEN_PATH = "ckpts/Qwen3-8B"
    CLIP_PATH = "ckpts/clip-vit-large-patch14"

# Per-LLM configuration: model path, chat template wrapping the user text,
# the default token offset where the user content starts ("crop_start" — the
# actual value is recomputed at runtime in HYTextModel._compute_crop_start),
# and the tokenizer/encoder classes to instantiate.
LLM_ENCODER_LAYOUT = {
    "qwen3": {
        "module_path": QWEN_PATH,
        "template": [
            {"role": "system", "content": f"{PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION}"},
            {"role": "user", "content": "{}"},
        ],
        "crop_start": 0,
        "tokenizer_class": AutoTokenizer,
        "text_encoder_class": AutoModelForCausalLM,
    },
}

# Per-sentence-embedding-model configuration: model path, tokenizer/encoder
# classes, how to pool token states into one vector, and max token length.
SENTENCE_EMB_LAYOUT = {
    "clipl": {
        "module_path": CLIP_PATH,
        "tokenizer_class": CLIPTokenizer,
        "text_encoder_class": CLIPTextModel,
        "pooling_mode": "pooler_output",
        "max_length": 77,
    },
}
48
+
49
+
50
class HYTextModel(nn.Module):
    """Dual frozen text encoder for HunyuanMotion.

    Combines:
      * an LLM (key into ``LLM_ENCODER_LAYOUT``, e.g. "qwen3") producing
        per-token context embeddings (``ctxt``) from a chat-templated
        prompt, and
      * a sentence-embedding model (key into ``SENTENCE_EMB_LAYOUT``, e.g.
        "clipl") producing one pooled vector (``vtxt``) per prompt.

    Either sub-encoder can be disabled by passing ``None`` for its type.
    Both are loaded in eval mode with gradients disabled.
    """

    def __init__(
        self,
        llm_type: Optional[str] = "qwen3",
        max_length_llm: int = 512,
        sentence_emb_type: Optional[str] = "clipl",
        max_length_sentence_emb: int = 77,
        enable_llm_padding: bool = True,
    ) -> None:
        """
        Args:
            llm_type: Key into ``LLM_ENCODER_LAYOUT``, or ``None`` to skip
                the LLM branch.
            max_length_llm: Max number of user-content tokens kept from the
                LLM (template tokens are cropped off and not counted).
            sentence_emb_type: Key into ``SENTENCE_EMB_LAYOUT``, or ``None``
                to skip the sentence-embedding branch.
            max_length_sentence_emb: Max token length for the sentence
                embedding tokenizer.
            enable_llm_padding: If True, pad LLM batches to ``max_length``;
                otherwise no padding.
        """
        super().__init__()
        self.text_encoder_type = "hy_text_model"

        # --- sentence-embedding (pooled vector) branch ---
        self.sentence_emb_type = sentence_emb_type
        self.sentence_emb_text_encoder = None
        self.sentence_emb_tokenizer = None
        self.vtxt_dim = 0
        if sentence_emb_type is not None:
            assert sentence_emb_type in SENTENCE_EMB_LAYOUT, f"Unsupported sentence embedding type: {sentence_emb_type}"
            self.max_length_sentence_emb = max_length_sentence_emb or SENTENCE_EMB_LAYOUT[sentence_emb_type].get(
                "max_length", 77
            )
            self._sentence_emb_pooling_mode = SENTENCE_EMB_LAYOUT[sentence_emb_type].get(
                "pooling_mode", "pooler_output"
            )
            tokenizer_kwargs = SENTENCE_EMB_LAYOUT[sentence_emb_type].get("tokenizer_kwargs", {})

            self.sentence_emb_tokenizer = SENTENCE_EMB_LAYOUT[sentence_emb_type]["tokenizer_class"].from_pretrained(
                SENTENCE_EMB_LAYOUT[sentence_emb_type]["module_path"],
                max_length=self.max_length_sentence_emb,
                **tokenizer_kwargs,
            )
            self.sentence_emb_text_encoder = SENTENCE_EMB_LAYOUT[sentence_emb_type][
                "text_encoder_class"
            ].from_pretrained(SENTENCE_EMB_LAYOUT[sentence_emb_type]["module_path"])
            self.sentence_emb_text_encoder = self.sentence_emb_text_encoder.eval().requires_grad_(False)
            self.vtxt_dim = self.sentence_emb_text_encoder.config.hidden_size

        # --- LLM (per-token context) branch ---
        self.llm_type = llm_type
        self.llm_text_encoder = None
        self.llm_tokenizer = None
        self.ctxt_dim = 0
        self.crop_start = 0
        self.max_length_llm = max_length_llm
        if llm_type is not None:
            assert llm_type in LLM_ENCODER_LAYOUT, f"Unsupported LLM type: {llm_type}"
            self._orig_max_length_llm = max_length_llm
            self.enable_llm_padding = enable_llm_padding
            self.llm_tokenizer = LLM_ENCODER_LAYOUT[llm_type]["tokenizer_class"].from_pretrained(
                LLM_ENCODER_LAYOUT[llm_type]["module_path"],
                padding_side="right",
            )
            self.llm_text_encoder = LLM_ENCODER_LAYOUT[llm_type]["text_encoder_class"].from_pretrained(
                LLM_ENCODER_LAYOUT[llm_type]["module_path"], low_cpu_mem_usage=True
            )
            self.llm_text_encoder = self.llm_text_encoder.eval().requires_grad_(False)
            self.ctxt_dim = self.llm_text_encoder.config.hidden_size

            # Tokenize a marker through the chat template once to find where
            # user content starts; budget extra length for the template.
            self.crop_start = self._compute_crop_start()
            self.max_length_llm = self._orig_max_length_llm + self.crop_start

    @torch.no_grad()
    def encode_llm(self, text: List[str]) -> Tuple[Tensor, Tensor]:
        """Encode texts with the LLM branch.

        Args:
            text: Batch of raw user prompts.

        Returns:
            Tuple of:
              * ctxt_raw: Hidden states for the user-content tokens, shape
                (B, _orig_max_length_llm, ctxt_dim) when padding is enabled.
              * ctxt_length: Number of valid user-content tokens per sample,
                clamped to [0, _orig_max_length_llm].

        Raises:
            ValueError: If the LLM branch was not initialized.
        """
        if self.llm_type is None or self.llm_text_encoder is None or self.llm_tokenizer is None:
            raise ValueError("LLM model not initialized")

        device = get_module_device(self)
        # Wrap each prompt in the chat template (string form for qwen3).
        llm_text = [
            (
                self.llm_tokenizer.apply_chat_template(
                    self.apply_text_to_template(one_text, LLM_ENCODER_LAYOUT[self.llm_type]["template"]),
                    tokenize=False,
                    add_generation_prompt=False,
                    enable_thinking=False,
                )
                if self.llm_type == "qwen3"
                else self.apply_text_to_template(one_text, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
            )
            for one_text in text
        ]
        padding_mode = "max_length" if self.enable_llm_padding else False
        llm_batch_encoding = self.llm_tokenizer(
            llm_text,
            return_length=False,
            return_overflowing_tokens=False,
            truncation=True,
            return_attention_mask=True,
            max_length=self.max_length_llm,  # = crop_start + _orig_max_length_llm
            padding=padding_mode,
            return_tensors="pt",
        )
        # qwen3 path requests hidden states (last layer used as context).
        llm_outputs = (
            self.llm_text_encoder(
                input_ids=llm_batch_encoding["input_ids"].to(device),
                attention_mask=llm_batch_encoding["attention_mask"].to(device),
                output_hidden_states=True,
            )
            if self.llm_type == "qwen3"
            else self.llm_text_encoder(
                input_ids=llm_batch_encoding["input_ids"].to(device),
                attention_mask=llm_batch_encoding["attention_mask"].to(device),
            )
        )
        if self.llm_type == "qwen3":
            ctxt_raw = llm_outputs.hidden_states[-1]
        else:
            ctxt_raw = llm_outputs.last_hidden_state

        # Crop off the chat-template prefix so only user-content tokens remain.
        start = self.crop_start
        end = start + self._orig_max_length_llm
        ctxt_raw = ctxt_raw[:, start:end].contiguous()  # [bs, _orig_max_length_llm, hidden]
        ctxt_length = (llm_batch_encoding["attention_mask"].sum(dim=-1).to(device) - start).clamp(
            min=0, max=self._orig_max_length_llm
        )
        return ctxt_raw, ctxt_length

    @torch.no_grad()
    def encode_sentence_emb(self, text: List[str]) -> Tensor:
        """Encode texts into pooled sentence embeddings of shape (B, 1, vtxt_dim).

        Raises:
            ValueError: If the sentence-embedding branch was not initialized,
                or the configured pooling mode is unknown.
        """
        if (
            self.sentence_emb_type is None
            or self.sentence_emb_text_encoder is None
            or self.sentence_emb_tokenizer is None
        ):
            raise ValueError("Sentence embedding model not initialized")

        device = get_module_device(self)
        enc = self.sentence_emb_tokenizer(
            text,
            return_length=False,
            return_overflowing_tokens=False,
            truncation=True,
            return_attention_mask=True,
            max_length=self.max_length_sentence_emb,
            padding=True,
            return_tensors="pt",
        )
        out = self.sentence_emb_text_encoder(
            input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device)
        )
        if self._sentence_emb_pooling_mode == "pooler_output":
            # Pooler-output pooling (e.g. clip-vit-large-patch14); falls back
            # to masked mean pooling when no pooler output is available.
            if hasattr(out, "pooler_output") and out.pooler_output is not None:
                vtxt_raw = out.pooler_output.unsqueeze(1)
            else:
                vtxt_raw = self._encode_pooling(enc["attention_mask"].to(device), out.last_hidden_state)
        elif self._sentence_emb_pooling_mode == "mean":
            vtxt_raw = self._encode_pooling(enc["attention_mask"].to(device), out.last_hidden_state)
        elif self._sentence_emb_pooling_mode == "last_token":
            vtxt_raw = self._last_token_pool(out.last_hidden_state, enc["attention_mask"].to(device))
        else:
            raise ValueError(f"Unknown pooling mode: {self._sentence_emb_pooling_mode}")

        return vtxt_raw

    def encode(self, text: List[str]) -> Tuple[Tensor, Tensor, Tensor]:
        """Encode texts with both branches.

        Returns:
            (vtxt_raw, ctxt_raw, ctxt_length) — pooled embedding, per-token
            context, and valid context lengths.
        """
        ctxt_raw, ctxt_length = self.encode_llm(text=text)
        vtxt_raw = self.encode_sentence_emb(text=text)
        return vtxt_raw, ctxt_raw, ctxt_length

    @staticmethod
    def apply_text_to_template(text: str, template: Union[str, list]) -> Union[str, list]:
        """Insert ``text`` into a format-string or chat-message template.

        Raises:
            TypeError: If ``template`` is neither a str nor a list.
        """
        if isinstance(template, str):
            return template.format(text)
        elif isinstance(template, list):
            return [
                {"role": "system", "content": f"{template[0]['content']}"},
                {"role": "user", "content": f"{text}"},
            ]
        else:
            raise TypeError(f"Unsupported template type: {type(template)}")

    def _compute_crop_start(self) -> int:
        """Locate the token index where user content starts in the template.

        Tokenizes the template with a marker string inserted as the user
        content and searches for the marker's token ids in the full id
        sequence. Falls back to the last token index if not found (e.g. the
        marker got split differently by surrounding context).

        Raises:
            ValueError: If the LLM branch was not initialized.
        """
        if self.llm_type is None or self.llm_text_encoder is None or self.llm_tokenizer is None:
            raise ValueError("LLM model not initialized")

        # Generic subsequence search over token-id lists.
        def _find_subseq(a: list, b: list) -> int:
            for i in range(0, len(a) - len(b) + 1):
                if a[i : i + len(b)] == b:
                    return i
            return -1

        marker = "<BOC>"
        if self.llm_type == "qwen3":
            msgs = self.apply_text_to_template(marker, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
            s = self.llm_tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=False, enable_thinking=False
            )
        else:
            s = self.apply_text_to_template(marker, LLM_ENCODER_LAYOUT[self.llm_type]["template"])
        full_ids = self.llm_tokenizer(s, return_tensors="pt", add_special_tokens=True)["input_ids"][0].tolist()
        marker_ids = self.llm_tokenizer(marker, return_tensors="pt", add_special_tokens=False)["input_ids"][0].tolist()
        pos = _find_subseq(full_ids, marker_ids)
        if pos >= 0:
            return pos
        else:
            return max(0, len(full_ids) - 1)

    def _pad_or_truncate_tensor(self, tensor: Tensor, target_length: int, dim: int = 0) -> Tensor:
        """Force ``tensor`` to ``target_length`` along ``dim``.

        Truncates when too long; pads by repeating the final slice when too
        short (zeros + broadcast of the last element along ``dim``).
        """
        current_length = tensor.shape[dim]
        if current_length > target_length:
            return tensor.narrow(dim, 0, target_length)
        elif current_length < target_length:
            pad_shape = list(tensor.shape)
            pad_shape[dim] = target_length - current_length
            padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor.narrow(dim, -1, 1)
            return torch.cat([tensor, padding], dim=dim)
        return tensor

    def _encode_pooling(self, attention_mask: Tensor, token_embeddings: Tensor) -> Tensor:
        """Masked mean pooling followed by L2 normalization -> (bs, 1, D)."""
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sentence_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
        vtxt_raw = nn.functional.normalize(sentence_embeddings, p=2, dim=1).unsqueeze(1)  # shape of [bs, 1, D]
        return vtxt_raw

    def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Pool by taking each sequence's last valid token, then L2-normalize.

        Handles both left- and right-padded batches: with left padding the
        last column of the mask is all ones, so position -1 is the last token.
        """
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            vtxt_raw = last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            vtxt_raw = last_hidden_states[
                torch.arange(batch_size, device=last_hidden_states.device),
                sequence_lengths,
            ]
        vtxt_raw = nn.functional.normalize(vtxt_raw, p=2, dim=-1).unsqueeze(1)  # shape of [bs, 1, D]
        return vtxt_raw
278
+
279
+
280
+ if __name__ == "__main__":
281
+ # python -m hymotion.network.text_encoders.text_encoder
282
+ text_encoder = HYTextModel(llm_type="qwen3", max_length_llm=5)
283
+ vtxt_raw, ctxt_raw, ctxt_length = text_encoder.encode(["Hello, world!"])
284
+ print(vtxt_raw.shape, ctxt_raw.shape, ctxt_length)
285
+
286
+ crop_start = text_encoder._compute_crop_start()
287
+ print(f"crop_start: {crop_start} when using {text_encoder.llm_type}")
288
+
289
+ assert (
290
+ vtxt_raw.shape[1:] == (1, text_encoder.vtxt_dim)
291
+ and ctxt_raw.shape[1:] == (text_encoder._orig_max_length_llm, text_encoder.ctxt_dim)
292
+ and torch.all((ctxt_length >= 0) & (ctxt_length <= text_encoder._orig_max_length_llm))
293
+ ), f"Got unexpected output shape: {vtxt_raw.shape}, {ctxt_raw.shape}, {ctxt_length}"
hymotion/network/token_refiner.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from torch import Tensor
7
+
8
+ from .attention import attention
9
+ from .bricks import get_norm_layer
10
+ from .encoders import MLP, MLPEncoder, TimestepEmbeddingEncoder
11
+ from .modulate_layers import ModulateDiT, apply_gate
12
+
13
+
14
class IndividualTokenRefinerBlock(nn.Module):
    """Pre-norm transformer block (self-attention + MLP) with gated residuals.

    An AdaLN-style modulation derived from the conditioning vector ``c``
    produces two gates (one per residual branch); there is no shift/scale
    modulation, only gating (factor=2 in ``ModulateDiT``).
    """

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        mlp_act_type: str = "silu",
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
    ) -> None:
        """Build attention/MLP sublayers.

        Args:
            feat_dim: token feature width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            mlp_ratio: hidden width multiplier for the feed-forward MLP.
            dropout: dropout probability inside the MLP.
            mlp_act_type: activation name for the MLP.
            qk_norm_type: norm layer applied to per-head Q/K.
            qkv_bias: whether the QKV and output projections carry bias.
        """
        super().__init__()
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        assert self.feat_dim % num_heads == 0, f"feat_dim {self.feat_dim} must be divisible by num_heads {num_heads}"
        self.head_dim = feat_dim // num_heads

        self.mlp_hidden_dim = int(feat_dim * mlp_ratio)

        self.norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=True, eps=1e-6)
        self.self_attn_qkv = nn.Linear(feat_dim, feat_dim * 3, bias=qkv_bias)
        # Per-head Q/K normalization (QK-Norm) stabilizes attention logits.
        self.self_attn_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.self_attn_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6)
        self.self_attn_proj = nn.Linear(feat_dim, feat_dim, bias=qkv_bias)

        self.norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=True, eps=1e-6)

        self.mlp = MLP(
            in_dim=feat_dim,
            feat_dim=self.mlp_hidden_dim,
            act_type=mlp_act_type,
            drop=dropout,
        )

        # factor=2 -> produces the two residual gates (attention / MLP).
        self.adaLN_modulation = ModulateDiT(
            feat_dim=feat_dim,
            factor=2,
            act_type="silu",
        )

    def forward(self, x: Tensor, c: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
        """Refine tokens ``x`` under conditioning ``c``.

        Args:
            x: (B, L, feat_dim) token features.
            c: conditioning vector fed to the gate projection.
            attn_mask: optional additive attention mask forwarded to `attention`.

        Returns:
            (B, L, feat_dim) refined token features.
        """
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1)
        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        # Apply QK-Norm if needed; cast back to v's dtype after normalization.
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)
        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
        return x
70
+
71
+
72
class IndividualTokenRefiner(nn.Module):
    """Stack of ``IndividualTokenRefinerBlock``s sharing one attention mask.

    The padding mask is expanded once into an additive self-attention mask
    and reused by every block.
    """

    def __init__(
        self,
        feat_dim: int,
        num_heads: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        mlp_act_type: str = "silu",
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    feat_dim=feat_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    mlp_act_type=mlp_act_type,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: Tensor, c: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Run all refiner blocks over ``x``.

        Args:
            x: (B, L, feat_dim) token features.
            c: conditioning vector passed to each block.
            mask: optional (B, L) padding mask (truthy = valid token).

        Returns:
            (B, L, feat_dim) refined tokens.
        """
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
            # a pair (q, k) is visible only when BOTH tokens are valid.
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # avoids self-attention weight being NaN for padding tokens
            # assume the shape of self_attn_mask is [B, H, Q, K] and this is self-attention (Q==K==L)
            L = self_attn_mask.size(-1)
            diag = torch.eye(L, dtype=torch.bool, device=self_attn_mask.device).view(1, 1, L, L)  # [1,1,L,L]
            # mark which query row is "all False" (no visible key)
            all_false = ~self_attn_mask.any(dim=-1, keepdim=False)  # [B, H, Q]
            # expand to [B, H, Q, K], only for these rows, back to diagonal visible
            all_false = all_false.unsqueeze(-1).expand(-1, -1, -1, L)
            self_attn_mask = torch.where(all_false, diag.expand_as(self_attn_mask), self_attn_mask)

        if self_attn_mask is not None:
            # Boolean visibility -> additive float mask (0 = visible, -inf = hidden).
            self_attn_mask = torch.where(
                self_attn_mask,
                torch.zeros_like(self_attn_mask, dtype=torch.float),
                torch.full_like(self_attn_mask, float("-inf"), dtype=torch.float),
            )
        for block in self.blocks:
            x = block(x, c, self_attn_mask)
        return x
131
+
132
+
133
class SingleTokenRefiner(nn.Module):
    """Project input tokens to ``feat_dim`` and refine them, conditioned on a
    timestep embedding plus a masked-mean context summary of the tokens."""

    def __init__(
        self,
        input_dim: int,
        feat_dim: int,
        num_heads: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        mlp_act_type: str = "silu",
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        attn_mode: str = "torch",
        **kwargs,
    ) -> None:
        """Build the embedder, conditioning encoders and the refiner stack.

        Note: extra ``**kwargs`` are accepted but ignored (config forwarding).
        """
        super().__init__()
        self.attn_mode = attn_mode
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        self.input_embedder = nn.Linear(input_dim, feat_dim, bias=True)
        self.context_encoder = MLPEncoder(
            in_dim=feat_dim,
            feat_dim=feat_dim,
            num_layers=2,
            act_type=mlp_act_type,
        )
        self.timestep_encoder = TimestepEmbeddingEncoder(
            embedding_dim=feat_dim,
            feat_dim=feat_dim,
            act_type=mlp_act_type,
        )

        self.individual_token_refiner = IndividualTokenRefiner(
            feat_dim=feat_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            mlp_act_type=mlp_act_type,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
        )

    def forward(self, x: Tensor, t: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Refine tokens ``x`` at timestep ``t``.

        Args:
            x: (B, L, input_dim) raw token features.
            t: timestep tensor consumed by the timestep encoder.
            mask: optional (B, L) padding mask (truthy = valid token).

        Returns:
            (B, L, feat_dim) refined tokens.

        NOTE(review): the context summary is computed from the *unprojected*
        input ``x`` (so input_dim must equal context_encoder's in_dim=feat_dim
        for this to work) — confirm against callers.
        """
        timestep_aware_representations = self.timestep_encoder(t)

        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            # Masked mean over valid tokens only; clamp avoids 0/0 for empty rows.
            mask_float = mask.float().unsqueeze(-1)
            denom = mask_float.sum(dim=1).clamp_min(1e-6)
            context_aware_representations = (x * mask_float).sum(dim=1) / denom
        context_aware_representations = self.context_encoder(context_aware_representations).unsqueeze(1)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)

        x = self.individual_token_refiner(x, c, mask)

        return x
hymotion/pipeline/body_model.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+
10
+ from ..utils.geometry import (
11
+ rot6d_to_rotation_matrix,
12
+ rotation_matrix_to_angle_axis,
13
+ )
14
+
15
+ # fmt: off
16
+ LEFT_HAND_MEAN_AA = [ 0.1117, 0.0429, -0.4164, 0.1088, -0.0660, -0.7562, -0.0964, -0.0909,
17
+ -0.1885, -0.1181, 0.0509, -0.5296, -0.1437, 0.0552, -0.7049, -0.0192,
18
+ -0.0923, -0.3379, -0.4570, -0.1963, -0.6255, -0.2147, -0.0660, -0.5069,
19
+ -0.3697, -0.0603, -0.0795, -0.1419, -0.0859, -0.6355, -0.3033, -0.0579,
20
+ -0.6314, -0.1761, -0.1321, -0.3734, 0.8510, 0.2769, -0.0915, -0.4998,
21
+ 0.0266, 0.0529, 0.5356, 0.0460, -0.2774]
22
+ RIGHT_HAND_MEAN_AA = [ 0.1117, -0.0429, 0.4164, 0.1088, 0.0660, 0.7562, -0.0964, 0.0909,
23
+ 0.1885, -0.1181, -0.0509, 0.5296, -0.1437, -0.0552, 0.7049, -0.0192,
24
+ 0.0923, 0.3379, -0.4570, 0.1963, 0.6255, -0.2147, 0.0660, 0.5069,
25
+ -0.3697, 0.0603, 0.0795, -0.1419, 0.0859, 0.6355, -0.3033, 0.0579,
26
+ 0.6314, -0.1761, 0.1321, 0.3734, 0.8510, -0.2769, 0.0915, -0.4998,
27
+ -0.0266, -0.0529, 0.5356, -0.0460, 0.2774]
28
+ # fmt: on
29
+
30
+
31
def to_tensor(array, dtype=torch.float32, device=torch.device("cpu")):
    """Convert *array* to a torch tensor on *device*.

    Existing tensors are moved to *device* unchanged (dtype preserved);
    everything else (list, ndarray, scalar) is wrapped via ``torch.tensor``
    with the requested *dtype*.

    BUG FIX: the original check was ``"torch.tensor" not in str(type(array))``,
    which never matches ``<class 'torch.Tensor'>`` (case-sensitive), so tensor
    inputs were re-wrapped with ``torch.tensor`` — an unnecessary copy that
    also coerced the dtype, contradicting the intended pass-through branch.
    """
    if isinstance(array, torch.Tensor):
        return array.to(device)
    return torch.tensor(array, dtype=dtype).to(device)
36
+
37
+
38
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
    """Calculates the rotation matrices for a batch of rotation vectors.

    Parameters
    ----------
    rot_vecs: torch.tensor (..., 3)
        Axis-angle vectors; extra leading dims are flattened and restored.
    epsilon: float
        Numerical guard added before the norm so zero rotations do not
        divide by zero. BUG FIX: this parameter was previously ignored
        (a hard-coded 1e-8 was used instead).
    dtype: torch.dtype
        dtype of the intermediate skew-symmetric matrices.

    Returns
    -------
    R: torch.tensor (..., 3, 3)
        The rotation matrices for the given axis-angle parameters.
    """
    if len(rot_vecs.shape) > 2:
        # Remember the original leading shape; work on a flat (N, 3) view.
        rot_vec_ori = rot_vecs
        rot_vecs = rot_vecs.view(-1, 3)
    else:
        rot_vec_ori = None
    batch_size = rot_vecs.shape[0]
    device = rot_vecs.device

    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
    rot_dir = rot_vecs / angle

    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Bx1 arrays
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)

    # Skew-symmetric cross-product matrix K of the (unit) rotation axis.
    # (The original code allocated a throwaway zeros K first; removed.)
    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    # Rodrigues' formula: R = I + sin(a) K + (1 - cos(a)) K^2
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    if rot_vec_ori is not None:
        rot_mat = rot_mat.reshape(*rot_vec_ori.shape[:-1], 3, 3)
    return rot_mat
75
+
76
+
77
def load_model_data(model_path):
    """
    Load wooden model data from binary files.

    Args:
        model_path: path to the directory containing .bin files

    Returns:
        dict containing:
            - v_template: (V, 3) vertex template
            - j_template: (J, 3) joint template
            - skin_weights: (V, 4) skin weights
            - skin_indices: (V, 4) skin indices
            - parents: (J,) parent indices (kintree)
            - faces: (F, 3) face indices
            - joint_names: list of joint names

    NOTE(review): arrays come from ``np.frombuffer`` and are read-only views;
    callers that need to mutate them should copy first — confirm downstream use.
    """
    model_path = Path(model_path)

    # Load vertex template: (V*3,) -> (V, 3)
    with open(model_path / "v_template.bin", "rb") as f:
        v_template_flat = np.frombuffer(f.read(), dtype=np.float32)
    num_verts = len(v_template_flat) // 3
    v_template = v_template_flat.reshape(num_verts, 3)

    # Load joint template: (J*3,) -> (J, 3)
    with open(model_path / "j_template.bin", "rb") as f:
        j_template_flat = np.frombuffer(f.read(), dtype=np.float32)
    num_joints = len(j_template_flat) // 3
    j_template = j_template_flat.reshape(num_joints, 3)

    # Load skin weights: (V*4,) -> (V, 4), 4 bones per vertex
    with open(model_path / "skinWeights.bin", "rb") as f:
        skin_weights_flat = np.frombuffer(f.read(), dtype=np.float32)
    skin_weights = skin_weights_flat.reshape(num_verts, 4)

    # Load skin indices: (V*4,) -> (V, 4), 4 bone indices per vertex
    with open(model_path / "skinIndice.bin", "rb") as f:
        skin_indices_flat = np.frombuffer(f.read(), dtype=np.uint16)
    skin_indices = skin_indices_flat.reshape(num_verts, 4).astype(np.int64)

    # Load kintree (parent indices): (J,)
    with open(model_path / "kintree.bin", "rb") as f:
        parents = np.frombuffer(f.read(), dtype=np.int32)

    # Load faces
    with open(model_path / "faces.bin", "rb") as f:
        faces_flat = np.frombuffer(f.read(), dtype=np.uint16)
    faces = faces_flat.reshape(-1, 3)

    # Load joint names; fall back to generated names when the json is absent.
    joint_names_path = model_path / "joint_names.json"
    if joint_names_path.exists():
        with open(joint_names_path, "r") as f:
            joint_names = json.load(f)
    else:
        joint_names = [f"Joint_{i}" for i in range(num_joints)]

    return {
        "v_template": v_template,
        "j_template": j_template,
        "skin_weights": skin_weights,
        "skin_indices": skin_indices,
        "parents": parents,
        "faces": faces,
        "joint_names": joint_names,
        "num_joints": num_joints,
        "num_verts": num_verts,
    }
146
+
147
+
148
def simple_lbs(v_template, rot_mats, joints, parents, skin_weights, skin_indices):
    """
    Simple Linear Blend Skinning without shape blending.

    Args:
        v_template: (V, 3) template vertices
        rot_mats: (B, J, 3, 3) rotation matrices for each joint
        joints: (J, 3) joint positions in rest pose
        parents: (J,) parent indices for each joint
        skin_weights: (V, 4) skin weights for 4 bones per vertex
        skin_indices: (V, 4) bone indices for 4 bones per vertex

    Returns:
        vertices: (B, V, 3) transformed vertices
        posed_joints: (B, J, 3) transformed joint positions
    """
    batch_size = rot_mats.shape[0]
    num_joints = rot_mats.shape[1]
    num_verts = v_template.shape[0]
    device = rot_mats.device
    dtype = rot_mats.dtype

    # Compute relative joint positions (each joint's offset from its parent;
    # joint 0 is assumed to be the root and keeps its absolute position).
    rel_joints = joints.clone()
    rel_joints[1:] = joints[1:] - joints[parents[1:]]

    # Build per-joint local 4x4 transforms: transforms_mat (B, J, 4, 4)
    transforms_mat = torch.zeros(batch_size, num_joints, 4, 4, device=device, dtype=dtype)
    transforms_mat[..., :3, :3] = rot_mats
    transforms_mat[..., :3, 3] = rel_joints.unsqueeze(0).expand(batch_size, -1, -1)
    transforms_mat[..., 3, 3] = 1.0

    # Forward kinematics: accumulate transforms from root to each joint.
    # NOTE: relies on parents[i] < i (topologically ordered kintree).
    transform_chain = [transforms_mat[:, 0]]
    for i in range(1, num_joints):
        parent_idx = parents[i].item()
        curr_transform = torch.bmm(transform_chain[parent_idx], transforms_mat[:, i])
        transform_chain.append(curr_transform)

    transforms = torch.stack(transform_chain, dim=1)  # (B, J, 4, 4)

    # Get posed joint positions
    posed_joints = transforms[..., :3, 3].clone()  # (B, J, 3)

    # Compute relative transforms (for skinning)
    # We need to subtract the rest pose joint positions from the transform
    rel_transforms = transforms.clone()
    joints_homo = F.pad(joints, [0, 1], value=0)  # (J, 4)
    transformed_rest = torch.einsum("bjcd,jd->bjc", transforms[..., :3, :], joints_homo)
    rel_transforms[..., :3, 3] = transforms[..., :3, 3] - transformed_rest[..., :3]

    # Apply skinning: gather transforms for each vertex's 4 bones
    # skin_indices: (V, 4), skin_weights: (V, 4)
    vertex_transforms = torch.zeros(batch_size, num_verts, 4, 4, 4, device=device, dtype=dtype)
    for k in range(4):
        bone_idx = skin_indices[:, k].long()  # (V,)
        vertex_transforms[:, :, k] = rel_transforms[:, bone_idx]  # (B, V, 4, 4)

    # Weight the transforms (blend the 4 bone transforms per vertex)
    skin_weights_expanded = skin_weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)  # (1, V, 4, 1, 1)
    skin_weights_expanded = skin_weights_expanded.expand(batch_size, -1, -1, 4, 4)  # (B, V, 4, 4, 4)

    weighted_transforms = (vertex_transforms * skin_weights_expanded).sum(dim=2)  # (B, V, 4, 4)

    # Apply to vertices (homogeneous coordinates, w=1)
    v_homo = F.pad(v_template, [0, 1], value=1.0)  # (V, 4)
    vertices = torch.einsum("bvcd,vd->bvc", weighted_transforms[..., :3, :], v_homo)  # (B, V, 3)

    return vertices, posed_joints
217
+
218
+
219
class WoodenMesh(torch.nn.Module):
    """
    Wooden character mesh model that loads from binary files.
    Uses simple LBS without shape blending (fixed skeleton).
    """

    def __init__(self, model_path="scripts/gradio/static/assets/dump_wooden"):
        """Load templates/skinning data from *model_path* and register buffers."""
        torch.nn.Module.__init__(self)

        # Load model data from .bin files
        model = load_model_data(model_path)

        # Register buffers like original SMPLMesh (moved with .to()/state_dict).
        v_template = to_tensor(model["v_template"])
        self.register_buffer("v_template", v_template)

        j_template = to_tensor(model["j_template"])
        self.register_buffer("j_template", j_template)

        skin_weights = to_tensor(model["skin_weights"])
        self.register_buffer("skin_weights", skin_weights)

        skin_indices = to_tensor(model["skin_indices"], dtype=torch.long)
        self.register_buffer("skin_indices", skin_indices)

        parents = to_tensor(model["parents"], dtype=torch.long)
        self.register_buffer("parents", parents)

        # Store non-buffer attributes
        self.faces = model["faces"]
        self.joint_names = model["joint_names"]
        self.num_joints = model["num_joints"]
        self.num_verts = model["num_verts"]

        print(f"[WoodenMesh] Loaded model: {self.num_verts} vertices, {self.num_joints} joints")

    def forward(self, params, fast_forward=False):
        """
        Forward pass to compute deformed vertices.

        Args:
            params: dict containing:
                - 'poses': (B, J*3) axis-angle rotations, or
                - 'rot6d': (B, J, 6) 6D rotation representations
                - 'trans': (B, 3) optional translation
            fast_forward: accepted for interface compatibility but unused here.

        Returns:
            dict with 'vertices', 'vertices_wotrans' and 'keypoints3d'
        """
        if "poses" in params:
            poses = params["poses"]
            batch_size = poses.shape[0]
            rot_mats = batch_rodrigues(poses.view(-1, 3)).view([batch_size, -1, 3, 3])
        elif "rot6d" in params:
            rot6d = params["rot6d"]
            batch_size = rot6d.shape[0]
            rot_mats = rot6d_to_rotation_matrix(rot6d).view([batch_size, -1, 3, 3])
        else:
            raise ValueError("poses or rot6d must be in params")

        # Body-only input (22 joints): pad 30 identity rotations.
        # NOTE(review): hard-coded 22+30=52 — presumably the 30 extra joints
        # are hand joints of this rig; confirm against joint_names.
        if rot_mats.shape[1] == 22:
            eye = torch.eye(3, device=rot_mats.device, dtype=rot_mats.dtype)[None, None, :, :].repeat(
                batch_size, 30, 1, 1
            )
            rot_mats = torch.cat([rot_mats, eye], dim=1)  # (B, 22 + 30, 3, 3)

        # Simple LBS (no shape blending, fixed skeleton)
        vertices, posed_joints = simple_lbs(
            self.v_template,
            rot_mats,
            self.j_template,
            self.parents,
            self.skin_weights,
            self.skin_indices,
        )

        # Vertices without translation (for pose-level supervision).
        # Same tensor object as `vertices` here; the translated version below
        # rebinds `vertices` to a new tensor, so this stays untranslated.
        vertices_wotrans = vertices

        if "trans" in params:
            trans = params["trans"]
            vertices = vertices + trans[:, None, :]

        return {
            "vertices": vertices,
            "vertices_wotrans": vertices_wotrans,
            "keypoints3d": posed_joints,
        }

    def forward_batch(self, params):
        """Run forward() on (B, T, ...) sequences by flattening batch and time.

        Args:
            params: dict with 'rot6d' (B, T, J, 6) and 'trans' (B, T, 3).

        Returns:
            dict of outputs reshaped back to (B, T, ...).
        """
        assert "rot6d" in params and "trans" in params
        rot6d = params["rot6d"]
        trans = params["trans"]
        bs, num_frames = rot6d.shape[:2]
        rot6d_flat = rot6d.reshape(bs * num_frames, rot6d.shape[2], rot6d.shape[3])
        trans_flat = trans.reshape(bs * num_frames, trans.shape[2])
        result = self.forward(
            {
                "rot6d": rot6d_flat,
                "trans": trans_flat,
            }
        )
        out = {}
        for key in result:
            out[key] = result[key].reshape(bs, num_frames, *result[key].shape[1:])
        return out
325
+
326
+
327
def construct_smpl_data_dict(
    rot6d: Tensor,
    transl: Tensor,
    betas: Optional[Tensor] = None,
    gender: str = "neutral",
    use_default_hand_mean_pose: bool = False,
) -> dict:
    """Convert 6D rotations + translation into a SMPL-H-style numpy dump dict.

    Args:
        rot6d: (T, J, 6) per-frame 6D joint rotations, J in {22, 52}.
        transl: (T, 3) per-frame root translations.
        betas: optional shape coefficients; zeros (1, 16) when None.
        gender: gender tag stored in the dump.
        use_default_hand_mean_pose: when J == 52, replace the hand joints
            (22..51) with the canonical mean hand pose.

    Returns:
        dict with numpy 'poses' (T, 52*3), 'trans', 'betas', 'gender',
        'mocap_framerate' (fixed 30), 'num_frames' and 'Rh' (root axis-angle).
    """
    rotation_matrix = rot6d_to_rotation_matrix(rot6d)
    angle_axis = rotation_matrix_to_angle_axis(rotation_matrix)
    # Canonical mean hand poses, tiled per frame: (T, 15, 3) each.
    left_hand_mean_pose = (
        torch.tensor(
            LEFT_HAND_MEAN_AA,
            device=angle_axis.device,
            dtype=angle_axis.dtype,
        )
        .unsqueeze(0)
        .repeat(angle_axis.shape[0], 1)
        .reshape(angle_axis.shape[0], -1, 3)
    )
    right_hand_mean_pose = (
        torch.tensor(
            RIGHT_HAND_MEAN_AA,
            device=angle_axis.device,
            dtype=angle_axis.dtype,
        )
        .unsqueeze(0)
        .repeat(angle_axis.shape[0], 1)
        .reshape(angle_axis.shape[0], -1, 3)
    )
    if angle_axis.shape[1] == 22:
        # Body-only pose: append mean hand poses to reach 52 joints.
        angle_axis = torch.cat(
            [
                angle_axis,
                left_hand_mean_pose,
                right_hand_mean_pose,
            ],
            dim=1,
        )
    elif angle_axis.shape[1] == 52:
        if use_default_hand_mean_pose:
            # Overwrite predicted hand joints with the mean hand pose.
            angle_axis = torch.cat(
                [
                    angle_axis[:, :22],
                    left_hand_mean_pose,
                    right_hand_mean_pose,
                ],
                dim=1,
            )
        else:
            angle_axis = angle_axis  # no-op: keep the full 52-joint pose as-is

    assert angle_axis.shape[1] == 52, f"angle_axis should be 52, but got {angle_axis.shape[1]}"
    dump = {
        "betas": betas.cpu().numpy() if betas is not None else np.zeros((1, 16)),
        "gender": gender,
        "poses": angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1),
        "trans": transl.cpu().numpy(),
        "mocap_framerate": 30,
        "num_frames": angle_axis.shape[0],
        # Rh = first 3 pose values per frame, i.e. the root (global) rotation.
        "Rh": angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1)[:, :3],
    }
    return dump
389
+
390
+
391
if __name__ == "__main__":
    # python -m hymotion.pipeline.body_model
    # Smoke test: single-frame forward (52 joints) and batched forward
    # (22 joints, which triggers the identity-padding path).
    model_path = "scripts/gradio/static/assets/dump_wooden"
    model = WoodenMesh(model_path)
    params = {
        "rot6d": torch.randn(1, 52, 6),
        "trans": torch.randn(1, 3),
    }
    result = model(params)
    print(result.keys())
    print(result["vertices"].shape)
    print(result["vertices_wotrans"].shape)
    print(result["keypoints3d"].shape)
    params_batch = {
        "rot6d": torch.randn(3, 100, 22, 6),
        "trans": torch.randn(3, 100, 3),
    }
    result_batch = model.forward_batch(params_batch)
    print(result_batch.keys())
    print(result_batch["vertices"].shape)
    print(result_batch["vertices_wotrans"].shape)
    print(result_batch["keypoints3d"].shape)
hymotion/pipeline/motion_diffusion.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ from copy import deepcopy
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from scipy.signal import savgol_filter
9
+ from torch import Tensor
10
+ from torchdiffeq import odeint
11
+
12
+ from ..utils.geometry import (
13
+ matrix_to_quaternion,
14
+ quaternion_fix_continuity,
15
+ quaternion_to_matrix,
16
+ rot6d_to_rotation_matrix,
17
+ rotation_matrix_to_rot6d,
18
+ )
19
+ from ..utils.loaders import load_object
20
+ from ..utils.motion_process import smooth_rotation
21
+ from ..utils.type_converter import get_module_device
22
+ from .body_model import WoodenMesh
23
+
24
+
25
def length_to_mask(lengths: Tensor, max_len: int) -> Tensor:
    """Build a boolean padding mask from per-sample sequence lengths.

    Args:
        lengths: (B,) or (B, 1) integer tensor of valid lengths.
        max_len: width of the mask.

    Returns:
        (B, max_len) bool tensor, True at positions < length.

    BUG FIX: the assertion message previously printed a stale hard-coded
    "max_len=72,298" instead of the actual max_len.
    """
    assert lengths.max() <= max_len, f"lengths.max()={lengths.max()} > max_len={max_len}"
    if lengths.ndim == 1:
        lengths = lengths.unsqueeze(1)
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths
    return mask
36
+
37
+
38
def start_end_frame_to_mask(start_frame: Tensor, end_frame: Tensor, max_len: int) -> Tensor:
    """Build a (B, max_len) boolean mask that is True only inside
    [start_frame, end_frame] (inclusive on both ends) per batch element.

    Args:
        start_frame: (B,) non-negative start indices.
        end_frame: (B,) non-negative end indices (inclusive).
        max_len: width of the mask.

    Returns:
        (B, max_len) bool tensor.

    BUG FIX: the length assertion previously printed a stale hard-coded
    "max_len=72,298" instead of the actual max_len; a dead
    ``lengths.unsqueeze(1)`` (result never used afterwards) was removed.
    """
    assert (start_frame >= 0).all() and (end_frame >= 0).all(), f"start_frame={start_frame}, end_frame={end_frame}"
    lengths = end_frame - start_frame + 1
    assert lengths.max() <= max_len, f"lengths.max()={lengths.max()} > max_len={max_len}"
    batch_size = start_frame.shape[0]
    arange_ids = torch.arange(max_len, device=start_frame.device).unsqueeze(0).expand(batch_size, max_len)
    mask = (arange_ids >= start_frame.unsqueeze(1)) & (arange_ids <= end_frame.unsqueeze(1))
    return mask
49
+
50
+
51
def randn_tensor(
    shape,
    generator=None,
    device=None,
    dtype=None,
    layout=None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`.

    When passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the
    tensor is always created on the CPU.

    Args:
        shape: output tensor shape; shape[0] is treated as the batch size.
        generator: a ``torch.Generator`` or a list of them (one per batch item).
        device: target device; defaults to CPU.
        dtype: output dtype (passed through to ``torch.randn``).
        layout: tensor layout; defaults to ``torch.strided``.
    """
    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            # CPU generator + non-CPU target: sample on CPU, then move.
            rand_device = "cpu"
            if device != "mps":
                print(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            # A CUDA generator cannot produce tensors for a different device.
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One generator per batch item: sample each row separately, then concat.
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
98
+
99
+
100
+ class MotionGeneration(torch.nn.Module):
101
    def __init__(
        self,
        network_module: str,
        network_module_args: dict,
        text_encoder_module: str,
        text_encoder_cfg: dict,
        mean_std_dir: str,
        motion_type="auto",
        **kwargs,
    ):
        """Build the motion transformer and learned null/special text features.

        Args:
            network_module: import path of the transformer network class.
            network_module_args: kwargs for the network constructor.
            text_encoder_module: import path of the text encoder (built lazily).
            text_encoder_cfg: kwargs for the text encoder constructor.
            mean_std_dir: directory holding Mean.npy/Std.npy normalisation stats.
            motion_type: motion representation name; "auto" resolves to "o6dp".
            **kwargs: optional flags (output_mesh_fps, train_frames,
                uncondition_mode, enable_ctxt_null_feat,
                enable_special_game_feat, random_generator_on_gpu).
        """
        super().__init__()
        # build models and parameters
        self._network_module_args = deepcopy(network_module_args)
        self.motion_transformer = load_object(network_module, network_module_args)
        self._text_encoder_module = text_encoder_module
        self._text_encoder_cfg = deepcopy(text_encoder_cfg)
        self.motion_type = motion_type

        # Learned placeholder text features used instead of encoder output
        # (null = unconditional; special_game = dedicated conditioning token).
        self.null_vtxt_feat = torch.nn.Parameter(
            torch.randn(1, 1, self._network_module_args.get("vtxt_input_dim", 768))
        )
        self.null_ctxt_input = torch.nn.Parameter(
            torch.randn(1, 1, self._network_module_args.get("ctxt_input_dim", 4096))
        )
        self.special_game_vtxt_feat = torch.nn.Parameter(
            torch.randn(1, 1, self._network_module_args.get("vtxt_input_dim", 768))
        )
        self.special_game_ctxt_feat = torch.nn.Parameter(
            torch.randn(1, 1, self._network_module_args.get("ctxt_input_dim", 4096))
        )
        # build buffer
        self.mean_std_dir = mean_std_dir
        self._parse_buffer(self.motion_type)

        self.output_mesh_fps = kwargs.get("output_mesh_fps", 30)
        self.train_frames = kwargs.get("train_frames", 360)
        self.uncondition_mode = kwargs.get("uncondition_mode", False)
        self.enable_ctxt_null_feat = kwargs.get("enable_ctxt_null_feat", False)
        self.enable_special_game_feat = kwargs.get("enable_special_game_feat", False)
        self.random_generator_on_gpu = kwargs.get("random_generator_on_gpu", True)
142
    def _parse_buffer(self, mode: str) -> None:
        """Build the fixed body model, resolve the motion type, and load
        the mean/std normalisation buffers."""
        self.body_model = WoodenMesh()
        self._find_motion_type(mode=mode)
        self._load_mean_std()
146
+
147
    def _load_mean_std(self, mean_std_name: Optional[str] = None) -> None:
        """Load Mean.npy/Std.npy from a directory into the `mean`/`std` buffers.

        Falls back to ``self.mean_std_dir`` when *mean_std_name* is None; when
        no valid directory exists, registers scalar zero-mean/unit-std buffers
        (normalisation becomes a no-op by broadcasting).
        """
        mean_std_name = self.mean_std_dir if mean_std_name is None else mean_std_name
        if mean_std_name is not None and osp.isdir(mean_std_name):
            mean = torch.from_numpy(np.load(osp.join(mean_std_name, "Mean.npy"))).float()
            std = torch.from_numpy(np.load(osp.join(mean_std_name, "Std.npy"))).float()
            # Stats are stored flat; validate as (1, 201) before registering.
            self._assert_motion_dimension(mean.unsqueeze(0), std.unsqueeze(0))
            self.register_buffer("mean", mean)
            self.register_buffer("std", std)
        else:
            print(
                f"[{self.__class__.__name__}] No mean_std found, using blank mean_std, "
                f"self.mean_std_dir={self.mean_std_dir}"
            )
            self.register_buffer("mean", torch.zeros(1))
            self.register_buffer("std", torch.ones(1))
162
+
163
    def _assert_motion_dimension(self, mean: Tensor, std: Tensor) -> None:
        """Validate that normalisation statistics have the expected (1, 201)
        motion-feature layout; raises AssertionError otherwise."""
        assert mean.shape == std.shape, f"mean.shape={mean.shape} != std.shape={std.shape}"
        assert mean.ndim == 2, f"mean.ndim={mean.ndim} != 2"
        assert mean.shape == (1, 201), f"mean.shape={mean.shape} != (1, 201)"
167
+
168
+ def _find_motion_type(self, mode: str) -> None:
169
+ if mode == "auto":
170
+ self.motion_type = "o6dp"
171
+ else:
172
+ self.motion_type = mode
173
+
174
    def set_epoch(self, epoch) -> None:
        """Store the current training epoch on the module."""
        self.current_epoch = epoch
176
+
177
    def load_in_demo(
        self,
        ckpt_name: str,
        mean_std_name: Optional[str] = None,
        build_text_encoder: bool = True,
        allow_empty_ckpt: bool = False,
    ) -> None:
        """Load checkpoint + normalisation stats and switch to eval for demos.

        Args:
            ckpt_name: checkpoint path; a missing file only warns (best-effort).
            mean_std_name: optional override for the mean/std location.
            build_text_encoder: build the text encoder unless running
                unconditionally.
            allow_empty_ckpt: skip checkpoint loading entirely when True.
        """
        if not allow_empty_ckpt:
            if not os.path.exists(ckpt_name):
                import warnings
                warnings.warn(f"Checkpoint {ckpt_name} not found, skipping model loading")
            else:
                # weights_only=False: checkpoint stores more than raw tensors.
                checkpoint = torch.load(ckpt_name, map_location="cpu", weights_only=False)
                self.load_state_dict(checkpoint["model_state_dict"], strict=False)
        if mean_std_name is not None:
            assert os.path.exists(mean_std_name), f"{mean_std_name} not found"
            # NOTE(review): `_load_mean_std` expects a *directory* (osp.isdir),
            # but this clears the name only when it is not a regular file — a
            # directory argument therefore passes through; confirm intent.
            if not os.path.isfile(mean_std_name):
                mean_std_name = None
        self._load_mean_std(mean_std_name)
        self.motion_transformer.eval()
        if build_text_encoder and not self.uncondition_mode:
            self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)
            self.text_encoder.to(get_module_device(self))
200
+
201
@torch.no_grad()
def encode_text(self, text: Dict[str, List[str]]) -> Dict[str, Tensor]:
    """Encode prompts into pooled and per-token text features (gradient-free).

    Expects ``text`` to be a dict with a ``"text"`` key holding a list of prompts.
    """
    # Lazily construct the text encoder on first use.
    if not hasattr(self, "text_encoder"):
        self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)
        self.text_encoder.to(get_module_device(self))
    prompts = text["text"]
    pooled_feat, token_feat, token_lengths = self.text_encoder.encode(text=prompts)
    return {
        "text_vec_raw": pooled_feat,
        "text_ctxt_raw": token_feat,
        "text_ctxt_raw_length": token_lengths,
    }
213
+
214
def decode_motion_from_latent(self, latent: Tensor, should_apply_smooothing: bool = True) -> Dict[str, Tensor]:
    """De-normalize a latent motion sequence and decode it into poses/keypoints."""
    # Dimensions whose std is (near-)zero carry no variance: zeroing the scale
    # collapses them back to the stored mean rather than amplifying noise.
    safe_std = torch.where(self.std < 1e-3, torch.zeros_like(self.std), self.std)
    latent_denorm = latent * safe_std + self.mean
    return self._decode_o6dp(
        latent_denorm,
        num_joints=22,
        rel_trans=False,
        should_apply_smooothing=should_apply_smooothing,
    )
224
+
225
def _forward_smpl_batch(
    self,
    root_rot6d: Tensor,  # (B, L, 1, 6)
    body_rot6d: Tensor,  # (B, L, 21, 6)
    transl: Tensor,  # (B, L, 3)
    left_hand_pose: Optional[Tensor] = None,  # (B, L, 15, 6)
    right_hand_pose: Optional[Tensor] = None,  # (B, L, 16, 6)
) -> Tensor:
    """Run the body model one batch element at a time and stack 3D keypoints.

    Returns a (B, L, J, 3) tensor on the CPU (each per-sample output is detached
    and moved off-device before stacking).
    """
    batch_size = transl.shape[0]
    # A single zero-beta (neutral body shape) is shared by every sample.
    neutral_betas = torch.zeros(1, 16, device=transl.device)
    per_sample = []
    for idx in range(batch_size):
        lhand = left_hand_pose[idx] if left_hand_pose is not None else None
        rhand = right_hand_pose[idx] if right_hand_pose is not None else None
        joints = self.body_model(
            body_rot6d[idx],
            neutral_betas,
            root_rot6d[idx],
            transl[idx],
            left_hand_pose=lhand,
            right_hand_pose=rhand,
        )
        per_sample.append(joints.detach().cpu())
    return torch.stack(per_sample, dim=0)  # (B, L, J, 3)
248
+
249
def _decode_o6dp(
    self,
    latent_denorm: torch.Tensor,
    num_joints: int,
    rel_trans: bool = False,
    should_apply_smooothing: bool = True,
) -> dict:
    """Decode a de-normalized o6dp latent into rotations, translation and keypoints.

    Latent layout (last dim): [0:3] translation (or per-frame velocity when
    ``rel_trans``), [3:9] root rotation (6D), then (num_joints - 1) * 6 body
    joint rotations in 6D form.
    """
    device = get_module_device(self)
    B, L = latent_denorm.shape[:2]
    nj = num_joints
    body_n = nj - 1

    if not rel_trans:
        transl = latent_denorm[..., 0:3].clone()
    else:
        # Velocities are integrated over time and scaled by the output FPS.
        transl = torch.cumsum(latent_denorm[..., 0:3].clone(), dim=1) / self.output_mesh_fps
    root_rot6d = latent_denorm[..., 3:9].reshape(B, L, 1, 6).clone()

    body6d_start = 9
    body6d_end = body6d_start + body_n * 6
    body_rot6d_full = latent_denorm[..., body6d_start:body6d_end].clone().reshape(B, L, body_n, 6)

    # 52 joints need to be split into hands
    left_hand_pose = right_hand_pose = None
    if nj == 52:
        body_rot6d = body_rot6d_full[:, :, :21, :].clone()
        left_hand_pose = body_rot6d_full[:, :, 21:36, :].clone()
        right_hand_pose = body_rot6d_full[:, :, 36:51, :].clone()
    else:
        body_rot6d = body_rot6d_full

    if left_hand_pose is not None and right_hand_pose is not None:
        body_full = torch.cat([body_rot6d, left_hand_pose, right_hand_pose], dim=2)
    else:
        body_full = body_rot6d
    rot6d = torch.cat([root_rot6d, body_full], dim=2)  # (B, L, nj, 6)
    if should_apply_smooothing:
        # only apply slerp smoothing to the first 22 joints (non-finger joints)
        rot6d_body = rot6d[:, :, :22, :]  # (B, L, 22, 6)
        rot6d_fingers = rot6d[:, :, 22:, :]  # (B, L, J-22, 6)
        rot6d_body_smooth = self.smooth_with_slerp(rot6d_body, sigma=1.0)
        rot6d_smooth = torch.cat([rot6d_body_smooth, rot6d_fingers], dim=2)
    else:
        rot6d_smooth = rot6d
    root_rotmat_smooth = rot6d_to_rotation_matrix(rot6d_smooth[:, :, 0, :])  # (B, L, 3, 3)

    transl_fixed = transl.detach()
    if should_apply_smooothing:
        # Savitzky-Golay smoothing on the translation trajectory only.
        transl_smooth = self.smooth_with_savgol(transl_fixed.detach(), window_length=11, polyorder=5)
    else:
        transl_smooth = transl_fixed

    if self.body_model is not None:
        # NOTE(review): debug prints — consider demoting to logging.debug.
        print(f'{self.__class__.__name__} rot6d_smooth shape: {rot6d_smooth.shape}, transl_smooth shape: {transl_smooth.shape}')
        with torch.no_grad():
            vertices_all = []
            k3d_all = []
            # Body model is evaluated per batch element.
            for bs in range(rot6d_smooth.shape[0]):
                out = self.body_model.forward(
                    {
                        'rot6d': rot6d_smooth[bs],
                        'trans': transl_smooth[bs],
                    }
                )
                vertices_all.append(out["vertices"])
                k3d_all.append(out['keypoints3d'])
            vertices = torch.stack(vertices_all, dim=0)
            k3d = torch.stack(k3d_all, dim=0)
            print(f'{self.__class__.__name__} vertices shape: {vertices.shape}, k3d shape: {k3d.shape}')
            # k3d = self._forward_smpl_batch(
            #     rot6d_smooth[:, :, 0:1, :].to(device),
            #     rot6d_smooth[:, :, 1:22, :].to(device),
            #     transl_smooth,
            #     left_hand_pose=(rot6d_smooth[:, :, 22:37, :].to(device) if left_hand_pose is not None else None),
            #     right_hand_pose=(rot6d_smooth[:, :, 37:52, :].to(device) if right_hand_pose is not None else None),
            # )
            # align with the ground
            # Lowest vertex height over all frames/vertices defines the floor offset.
            min_y = vertices[..., 1].amin(dim=(1, 2), keepdim=True)  # (B, 1, 1)
            print(f'{self.__class__.__name__} min_y: {min_y}')
            k3d = k3d.clone()
            k3d[..., 1] -= min_y  # (B, L, J) - (B, 1, 1)
            transl_smooth = transl_smooth.clone()
            transl_smooth[..., 1] -= min_y.squeeze(-1).to(device)  # (B, L) - (B, 1)
    else:
        # No body model available: keypoints are placeholder zeros.
        k3d = torch.zeros(B, L, nj, 3, device=device)

    return dict(
        latent_denorm=latent_denorm,  # (B, L, 201)
        keypoints3d=k3d,  # (B, L, J, 3)
        rot6d=rot6d_smooth,  # (B, L, J, 6)
        transl=transl_smooth,  # (B, L, 3)
        root_rotations_mat=root_rotmat_smooth,  # (B, L, 3, 3)
    )
342
+
343
@staticmethod
def smooth_with_savgol(input: torch.Tensor, window_length: int = 9, polyorder: int = 5) -> torch.Tensor:
    """Savitzky-Golay-smooth each channel along the time axis.

    Accepts (L, C) or (B, L, C) input; the output has the same shape, dtype
    and device as the input.
    """
    squeeze_back = len(input.shape) == 2
    batched = input.unsqueeze(0) if squeeze_back else input
    arr = batched.cpu().numpy()
    smoothed = np.empty_like(arr, dtype=np.float32)
    # scipy's filter is 1-D, so run it per (batch, channel) series.
    for b in range(arr.shape[0]):
        for c in range(arr.shape[2]):
            smoothed[b, :, c] = savgol_filter(arr[b, :, c], window_length, polyorder)
    result = torch.from_numpy(smoothed).to(batched)
    return result.squeeze(0) if squeeze_back else result
359
+
360
@staticmethod
def smooth_with_slerp(input: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Smooth per-joint rotations over time in quaternion space.

    Converts 6D rotations -> matrices -> quaternions, enforces temporal sign
    continuity, applies Gaussian smoothing (``sigma``), and converts back.
    Input/output shape: (B, L, J, 6); time axis is dim 1.
    """

    def fix_time_continuity(q: Tensor, time_dim: int = -3):
        # Flatten everything but time, fix q/-q sign flips frame-to-frame,
        # then restore the original axis layout.
        shape = q.shape
        qv = q.moveaxis(time_dim, 0).contiguous().view(shape[time_dim], -1, 4)
        qv = quaternion_fix_continuity(qv)
        return qv.view(shape[time_dim], *shape[:time_dim], *shape[time_dim + 1 :]).moveaxis(0, time_dim)

    num_joints = input.shape[2]
    RR = rot6d_to_rotation_matrix(input)
    qq = matrix_to_quaternion(RR)
    qq_np = fix_time_continuity(qq, time_dim=1).cpu().numpy()
    # smooth_rotation operates on numpy quaternions; presumably a Gaussian
    # slerp-style filter — TODO confirm against its definition.
    qq_s_np = smooth_rotation(
        qq_np,
        sigma=sigma,
    )
    input_smooth = rotation_matrix_to_rot6d(quaternion_to_matrix(torch.from_numpy(qq_s_np)))
    return input_smooth.to(input.device)
378
+
379
@staticmethod
def noise_from_seeds(
    latent: Tensor, seeds: Union[int, List[int]], seed_start: int = 0, random_generator_on_gpu: bool = True
) -> Tensor:
    """Draw one reproducible Gaussian noise tensor per seed and concatenate along batch.

    An integer ``seeds`` is expanded to ``range(seeds)``. Each sample has the same
    shape as ``latent``; the result therefore has ``len(seeds) * B`` rows.
    """
    if isinstance(seeds, int):
        seeds = list(range(seeds))
    target_shape = (latent.shape[0], *latent.shape[1:])
    samples = []
    for seed in seeds:
        effective_seed = seed + seed_start
        if random_generator_on_gpu:
            # Sample directly on the latent's device for speed.
            gen = torch.Generator(device=latent.device).manual_seed(effective_seed)
            sample = randn_tensor(target_shape, generator=gen, device=latent.device, dtype=latent.dtype)
        else:
            # CPU generator gives device-independent reproducibility.
            gen = torch.Generator().manual_seed(effective_seed)
            sample = randn_tensor(target_shape, generator=gen, dtype=latent.dtype).to(latent.device)
        samples.append(sample)
    return torch.cat(samples, dim=0)
397
+
398
def _maybe_inject_source_token(
    self,
    vtxt_input: Tensor,
    ctxt_input: Tensor,
    ctxt_mask_temporal: Tensor,
    sources: Optional[List[str]],
    trigger_sources: Optional[set] = None,
    prob: float = 0.5,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Conditionally inject a learned "special game" token into the text features.

    For batch entries whose source matches ``trigger_sources`` (probabilistically
    during training, always at eval), adds a pooled-feature offset and appends a
    context token, extending the temporal mask accordingly. Returns the possibly
    modified (vtxt_input, ctxt_input, ctxt_mask_temporal).
    """
    # Feature disabled or no source info: pass inputs through unchanged.
    if (sources is None or trigger_sources is None) or not self.enable_special_game_feat:
        return vtxt_input, ctxt_input, ctxt_mask_temporal

    B, Lc, Dc = ctxt_input.shape
    assert (
        isinstance(sources, (list, tuple)) and len(sources) == B
    ), f"sources length should be equal to batch: {len(sources)} vs {B}"

    # Case-insensitive source matching.
    trig = set(s.lower() for s in trigger_sources)
    src_mask = torch.tensor(
        [str(s).lower() in trig for s in sources], dtype=torch.bool, device=ctxt_input.device
    )  # (B,)
    if not src_mask.any():
        return vtxt_input, ctxt_input, ctxt_mask_temporal

    # Training: inject with probability `prob`; eval: always inject for hits.
    rand_mask = (
        torch.rand(B, device=ctxt_input.device) < prob
        if self.training
        else torch.BoolTensor(B).fill_(True).to(ctxt_input.device)
    )
    apply_mask = src_mask & rand_mask
    if not apply_mask.any():
        return vtxt_input, ctxt_input, ctxt_mask_temporal

    # vtxt: only add mixture to the hit samples
    vtxt_token = self.special_game_vtxt_feat.to(vtxt_input).expand(B, 1, -1)
    vtxt_input = vtxt_input + vtxt_token * apply_mask.view(B, 1, 1).to(vtxt_input.dtype)

    # calculate the current effective length of each sample
    if ctxt_mask_temporal.dtype == torch.bool:
        cur_len = ctxt_mask_temporal.sum(dim=1).long()  # (B,)
    else:
        cur_len = (ctxt_mask_temporal > 0).sum(dim=1).long()

    # for the "not full" hit samples,
    # write the special token at the cur_len position,
    # and set the mask to True
    can_inplace = apply_mask & (cur_len < Lc)
    b_inplace = torch.nonzero(can_inplace, as_tuple=False).squeeze(1)  # (K,)
    if b_inplace.numel() > 0:
        pos = cur_len[b_inplace]  # (K,)
        token = self.special_game_ctxt_feat.squeeze(0).squeeze(0).to(ctxt_input)  # (Dc,)
        # NOTE(review): writes into ctxt_input in place — callers sharing this
        # tensor observe the mutation.
        ctxt_input[b_inplace, pos, :] = token.unsqueeze(0).expand(b_inplace.numel(), Dc)
        if ctxt_mask_temporal.dtype == torch.bool:
            ctxt_mask_temporal[b_inplace, pos] = True
        else:
            ctxt_mask_temporal[b_inplace, pos] = 1

    # if there are "full" hit samples, need to pad one:
    # the full samples write the special token at the new position,
    # other samples pad zero and mask=False
    need_expand = (apply_mask & (cur_len >= Lc)).any()
    if need_expand:
        suffix = torch.zeros((B, 1, Dc), dtype=ctxt_input.dtype, device=ctxt_input.device)
        full_hit = apply_mask & (cur_len >= Lc)
        b_full = torch.nonzero(full_hit, as_tuple=False).squeeze(1)
        if b_full.numel() > 0:
            suffix[b_full, 0, :] = (
                self.special_game_ctxt_feat.expand(b_full.numel(), 1, -1).to(ctxt_input).squeeze(1)
            )
        ctxt_input = torch.cat([ctxt_input, suffix], dim=1)

        if ctxt_mask_temporal.dtype == torch.bool:
            suffix_mask = torch.zeros((B, 1), dtype=torch.bool, device=ctxt_input.device)
            suffix_mask[b_full, 0] = True
        else:
            suffix_mask = torch.zeros((B, 1), dtype=ctxt_mask_temporal.dtype, device=ctxt_input.device)
            suffix_mask[b_full, 0] = 1
        ctxt_mask_temporal = torch.cat([ctxt_mask_temporal, suffix_mask], dim=1)

    return vtxt_input, ctxt_input, ctxt_mask_temporal
478
+
479
+
480
class MotionFlowMatching(MotionGeneration):
    """Flow-matching motion generator: samples motion latents by integrating an ODE."""

    def __init__(
        self,
        network_module: str,
        network_module_args: dict,
        text_encoder_module: str,
        text_encoder_cfg: dict,
        # NOTE(review): mutable default arguments — shared across instances;
        # safer as None with in-body defaults.
        noise_scheduler_cfg: dict = {"method": "euler"},
        infer_noise_scheduler_cfg: dict = {"validation_steps": 50},
        mean_std_dir: Optional[str] = None,
        losses_cfg: Optional[dict] = None,
        train_cfg: Optional[dict] = None,
        test_cfg: Optional[dict] = None,
        **kwargs,
    ):
        # NOTE(review): if both mean_std_dir and test_cfg are None this raises
        # AttributeError on test_cfg.get — confirm callers always pass one.
        super().__init__(
            network_module=network_module,
            network_module_args=network_module_args,
            text_encoder_module=text_encoder_module,
            text_encoder_cfg=text_encoder_cfg,
            losses_cfg=losses_cfg,
            mean_std_dir=(mean_std_dir if mean_std_dir is not None else test_cfg.get("mean_std_dir", None)),
            **kwargs,
        )
        # build scheduler
        self._noise_scheduler_cfg = deepcopy(noise_scheduler_cfg)
        self._infer_noise_scheduler_cfg = deepcopy(infer_noise_scheduler_cfg)
        # additional cfg
        self.train_cfg = deepcopy(train_cfg) if train_cfg else dict()
        self.test_cfg = deepcopy(test_cfg) if test_cfg else dict()
        self._parse_test_cfg()

    def _parse_test_cfg(self) -> None:
        """Cache inference settings (ODE steps, CFG scale) from the configs."""
        self.validation_steps = self._infer_noise_scheduler_cfg["validation_steps"]
        self.text_guidance_scale = self.test_cfg.get("text_guidance_scale", 1)

    @torch.no_grad()
    def generate(
        self,
        text: Union[str, List[str]],
        seed_input: List[int],
        duration_slider: int,
        cfg_scale: Optional[float] = None,
        use_special_game_feat: bool = False,
        hidden_state_dict=None,
        length=None,
    ) -> Dict[str, Any]:
        """Generate motion for the given text prompt(s), one sample per seed.

        ``length`` (frames) overrides ``duration_slider`` (seconds) when given;
        it is clamped to the trained frame range. Pre-computed text features can
        be supplied via ``hidden_state_dict`` to skip re-encoding.
        """
        device = get_module_device(self)
        if length is None:
            length = int(round(duration_slider * self.output_mesh_fps))
        assert (
            0 < length < 5000
        ), f"input duration_slider must be in (0, {5000/self.output_mesh_fps}] due to rope, but got {duration_slider}"
        if length > self.train_frames or length < min(self.train_frames, 20):
            print(f">>> given length is too long or too short, got {length}, will be truncated")
            length = min(length, self.train_frames)
            length = max(length, min(self.train_frames, 20))

        # Broadcast a single prompt across all seeds.
        repeat = len(seed_input)
        if isinstance(text, list):
            assert len(text) == repeat, f"len(text) must equal len(seed_input), got {len(text)} vs {repeat}"
            text_list = text
        elif isinstance(text, str):
            text_list = [text] * repeat
        else:
            raise TypeError(f"Unsupported text type: {type(text)}")

        if not self.uncondition_mode:
            if hidden_state_dict is None:
                hidden_state_dict = self.encode_text({"text": text_list})
            vtxt_input = hidden_state_dict["text_vec_raw"]
            ctxt_input = hidden_state_dict["text_ctxt_raw"]
            ctxt_length = hidden_state_dict["text_ctxt_raw_length"]
            # check shape
            if len(vtxt_input.shape) == 2 and len(ctxt_input.shape) == 2:
                vtxt_input = vtxt_input[None].repeat(repeat, 1, 1)
                ctxt_input = ctxt_input[None].repeat(repeat, 1, 1)
                ctxt_length = ctxt_length.repeat(repeat)
            ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
            sources = None if not use_special_game_feat else ["Game"] * repeat
            vtxt_input, ctxt_input, ctxt_mask_temporal = self._maybe_inject_source_token(
                vtxt_input, ctxt_input, ctxt_mask_temporal, sources, trigger_sources={"Taobao", "Game"}
            )
        else:
            # Unconditional generation: use the learned null features.
            vtxt_input = self.null_vtxt_feat.expand(repeat, 1, -1)
            ctxt_input = self.null_ctxt_input.expand(repeat, 1, -1)
            ctxt_length = torch.tensor([1]).expand(repeat)
            ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1]).expand(repeat, -1)
        assert len(vtxt_input.shape) == 3, f"vtxt_input.shape: {vtxt_input.shape}, should be (B, 1, D)"
        assert len(ctxt_input.shape) == 3, f"ctxt_input.shape: {ctxt_input.shape}, should be (B, 1, D)"
        assert len(ctxt_length.shape) == 1, f"ctxt_length.shape: {ctxt_length.shape}, should be (B,)"

        # NOTE(review): this recompute overwrites any mask update made by
        # _maybe_inject_source_token above (the injected token position would be
        # masked out again) — confirm whether this line should be removed.
        ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
        x_length = torch.LongTensor([length] * repeat).to(device)
        x_mask_temporal = length_to_mask(x_length, self.train_frames)

        text_guidance_scale = cfg_scale if cfg_scale is not None else self.text_guidance_scale
        do_classifier_free_guidance = text_guidance_scale > 1.0 and not self.uncondition_mode
        if do_classifier_free_guidance is True:
            # Stack [unconditional, conditional] along the batch for one-pass CFG.
            silent_text_feat = self.null_vtxt_feat.expand(*vtxt_input.shape)
            vtxt_input = torch.cat([silent_text_feat, vtxt_input], dim=0)

            if self.enable_ctxt_null_feat:
                silent_ctxt_input = self.null_ctxt_input.expand(*ctxt_input.shape)
            else:
                silent_ctxt_input = ctxt_input
            ctxt_input = torch.cat([silent_ctxt_input, ctxt_input], dim=0)

            ctxt_mask_temporal = torch.cat([ctxt_mask_temporal] * 2, dim=0)
            x_mask_temporal = torch.cat([x_mask_temporal] * 2, dim=0)

        def fn(t: Tensor, x: Tensor) -> Tensor:
            # predict flow
            x_input = torch.cat([x] * 2, dim=0) if do_classifier_free_guidance else x
            x_pred = self.motion_transformer(
                x=x_input,
                ctxt_input=ctxt_input,
                vtxt_input=vtxt_input,
                timesteps=t.expand(x_input.shape[0]),
                x_mask_temporal=x_mask_temporal,
                ctxt_mask_temporal=ctxt_mask_temporal,
            )
            if do_classifier_free_guidance:
                x_pred_basic, x_pred_text = x_pred.chunk(2, dim=0)
                x_pred = x_pred_basic + text_guidance_scale * (x_pred_text - x_pred_basic)
            return x_pred

        # duplicate test corner for inner time step observation
        t = torch.linspace(0, 1, self.validation_steps + 1, device=device)
        y0 = self.noise_from_seeds(
            torch.zeros(
                1,
                self.train_frames,
                self._network_module_args["input_dim"],
                device=device,
            ),
            seed_input,
            random_generator_on_gpu=self.random_generator_on_gpu,
        )
        with torch.no_grad():
            # torchdiffeq-style solver; cfg supplies e.g. method="euler".
            trajectory = odeint(fn, y0, t, **self._noise_scheduler_cfg)
        sampled = trajectory[-1]
        assert isinstance(sampled, Tensor), f"sampled must be a Tensor, but got {type(sampled)}"
        sampled = sampled[:, :length, ...].clone()

        output_dict = self.decode_motion_from_latent(sampled, should_apply_smooothing=True)

        return {
            **output_dict,
            "text": text,
        }
631
+
632
+
633
+ if __name__ == "__main__":
634
+ # python -m hymotion.pipeline.motion_diffusion
635
+ import time
636
+
637
+ import torch
638
+
639
+ device = "cuda:0"
640
+ bsz, input_dim = 64, 272
641
+ seq_lens = [90, 180, 360]
642
+ ctxt_seq_lens = 64
643
+ warmup = 5
644
+ repeats = 100
645
+
646
+ network_module = "hymotion/network/hymotion_mmdit.HunyuanMotionMMDiT"
647
+ network_module_args = {
648
+ "input_dim": input_dim,
649
+ "feat_dim": 512,
650
+ "ctxt_input_dim": 4096,
651
+ "vtxt_input_dim": 768,
652
+ "num_layers": 12,
653
+ "num_heads": 4,
654
+ "mlp_ratio": 2.0,
655
+ "dropout": 0.0,
656
+ "mask_mode": "narrowband",
657
+ }
658
+ text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel"
659
+ text_encoder_cfg = {"llm_type": "qwen3", "max_length_llm": ctxt_seq_lens}
660
+
661
+ # ================================ FM_MMDiT ================================
662
+ FM_MMDiT = MotionFlowMatching(
663
+ network_module=network_module,
664
+ network_module_args=network_module_args,
665
+ text_encoder_module=text_encoder_module,
666
+ text_encoder_cfg=text_encoder_cfg,
667
+ noise_scheduler_module={"method": "euler"},
668
+ infer_noise_scheduler_cfg={"validation_steps": 50},
669
+ train_cfg={"cond_mask_prob": 0.1},
670
+ test_cfg={
671
+ "text_guidance_scale": 1.5,
672
+ },
673
+ ).to(device)
hymotion/prompt_engineering/model_constants.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Public API of this constants module: only the rewrite prompt template.
__all__ = [
    "REWRITE_AND_INFER_TIME_PROMPT_FORMAT",
]

# LLM prompt template for rewriting a motion caption and estimating duration.
# The trailing `{}` placeholder receives the raw user text via str.format();
# literal JSON braces in the template are escaped as `{{` / `}}` so .format()
# leaves them intact.
REWRITE_AND_INFER_TIME_PROMPT_FORMAT = """
# Role
You are an expert in 3D motion analysis, animation timing, and choreography. Your task is to analyze textual action descriptions to estimate execution time and standardize the language for motion generation systems.

# Task
Analyze the user-provided [Input Action] and generate a structured JSON response containing a duration estimate and a refined caption.

# Instructions

### 1. Duration Estimation (frame_count)
- Analyze the complexity, speed, and physical constraints of the described action.
- Estimate the time required to perform the action in a **smooth, natural, and realistic manner**.
- Calculate the total duration in frames based on a **30 fps** (frames per second) standard.
- Output strictly as an Integer.

### 2. Caption Refinement (short_caption)
- Generate a refined, grammatically correct version of the input description in **English**.
- **Strict Constraints**:
- You must **PRESERVE** the original sequence of events (chronological order).
- You must **RETAIN** all original spatial modifiers (e.g., "left," "upward," "quickly").
- **DO NOT** add new sub-actions or hallucinate details not present in the input.
- **DO NOT** delete any specific movements.
- The goal is to improve clarity and flow while maintaining 100% semantic fidelity to the original request.

### 3. Output Format
- Return **ONLY** a raw JSON object.
- Do not use Markdown formatting (i.e., do not use ```json ... ```).
- Ensure the JSON is valid and parsable.

# JSON Structure
{{
"duration": <Integer, frames at 30fps>,
"short_caption": "<String, the refined English description>"
}}

# Input
{}
"""
hymotion/prompt_engineering/prompt_rewrite.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # prompt_rewrite.py
2
+ import base64
3
+ import concurrent.futures
4
+ import datetime
5
+ import hashlib
6
+ import hmac
7
+ import json
8
+ import logging
9
+ import random
10
+ import re
11
+ import time
12
+ import uuid
13
+ from dataclasses import dataclass
14
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
15
+
16
+ from openai import OpenAI
17
+ from requests import exceptions as req_exc
18
+
19
+ from .model_constants import REWRITE_AND_INFER_TIME_PROMPT_FORMAT
20
+
21
+ # logging.basicConfig(level=logging.INFO)
22
+
23
+
24
@dataclass
class ApiConfig:
    """Connection settings for a chat-completion backend.

    Only ``host``, ``apikey`` and ``model`` are consumed by OpenAIChatApi in
    this file; the remaining fields appear unused here — presumably kept for
    other backends. TODO confirm.
    """

    host: str  # base URL of the OpenAI-compatible endpoint
    user: str
    apikey: str  # API key; "EMPTY" for local/unauthenticated servers
    model: str  # model name sent with every request
    api_version: Optional[str] = None
    timeout: int = 3600  # seconds; not wired into the client in this file
    source: str = "hymotion"
33
+
34
+
35
@dataclass
class RetryConfig:
    """Retry policy for LLM calls (used by ResponseParser.call_data_eval_with_retry)."""

    max_retries: int = 20  # total attempts before giving up
    base_delay: float = 1.0  # starting backoff in seconds
    timeout: float = 30.0  # per-request timeout; not consumed in this file
    retry_status: Tuple[int, ...] = (429, 500, 502, 503, 504)  # retryable HTTP codes
    # NOTE(review): max_delay == base_delay caps the exponential backoff at 1s,
    # making the 2**attempt growth ineffective — confirm this is intentional.
    max_delay: float = 1.0
42
+
43
+
44
class ApiError(Exception):
    """Raised when the chat-completion backend fails or retries are exhausted."""
46
+
47
+
48
class ResponseParseError(Exception):
    """Raised when an LLM response cannot be parsed or fails validation."""
50
+
51
+
52
class OpenAIChatApi:
    """Minimal chat-completion client built on the OpenAI SDK."""

    def __init__(self, config: ApiConfig) -> None:
        self.logger = logging.getLogger(__name__)
        self.config = config
        self.client = OpenAI(
            api_key=self.config.apikey,
            base_url=self.config.host,
        )

    @staticmethod
    def _flatten_content(content) -> str:
        """Collapse structured message content into a plain string."""
        if isinstance(content, list):
            texts = [
                str(part.get("text", ""))
                for part in content
                if isinstance(part, dict) and ("text" in part)
            ]
            return " ".join(t for t in texts if t)
        if not isinstance(content, str):
            return str(content)
        return content

    def call_data_eval(self, data: Union[str, Dict[str, Any]]):
        """Send *data* as a chat-completion request and return the raw response.

        A dict with a "messages" key is normalized and forwarded (selected
        sampling parameters are passed through); anything else becomes a single
        user message. Raises ApiError on any client failure.
        """
        if isinstance(data, dict) and "messages" in data:
            normalized = [
                {
                    "role": m.get("role", "user"),
                    "content": self._flatten_content(m.get("content", "")),
                }
                for m in data["messages"]
            ]
            payload = {
                "model": self.config.model,
                "messages": normalized,
                "temperature": 0.7,
                "top_p": 0.8,
            }
            # Caller-supplied sampling options override the defaults above.
            for key in (
                "temperature",
                "top_p",
                "max_tokens",
                "n",
                "stop",
                "presence_penalty",
                "frequency_penalty",
                "user",
            ):
                if key in data:
                    payload[key] = data[key]
        else:
            payload = {
                "model": self.config.model,
                "messages": [{"role": "user", "content": str(data)}],
                "temperature": 0.7,
                "top_p": 0.8,
            }
        try:
            return self.client.chat.completions.create(**payload)
        except Exception as e:
            self.logger.error(f"OpenAI API call failed: {e}")
            raise ApiError(f"OpenAI API call failed: {e}") from e
108
+
109
+
110
class ResponseParser:
    """Calls a chat API with retries and parses the JSON answer it returns."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def call_data_eval_with_retry(
        self, api: Union[OpenAIChatApi], data: str, retry_config: Optional[RetryConfig] = None
    ) -> Tuple[Union[Dict[str, Any], int], float, float]:
        """Call ``api`` with exponential backoff; return (parsed_result, cost, elapsed).

        Retries on timeouts, request errors and parse/validation failures;
        raises ApiError once retries are exhausted or on a non-retryable HTTP
        status. ``cost`` is currently always 0.0 (see NOTE on _extract_cost).
        """
        if retry_config is None:
            retry_config = RetryConfig()

        last_error = None
        for attempt in range(retry_config.max_retries):
            start_time = time.time()
            cost = 0.0

            try:
                result = self._execute_request(api, data)
                end_time = time.time()
                parsed_result = self._parse_answer(result)
                self._validate_result(parsed_result)
                return parsed_result, cost, end_time - start_time

            except (
                concurrent.futures.TimeoutError,
                req_exc.RequestException,
                json.JSONDecodeError,
                ValueError,
                TypeError,
                ResponseParseError,
            ) as e:
                last_error = e
                self.logger.warning(f"Attempt {attempt + 1} failed: {e}")
                # HTTP errors outside the retryable set abort immediately.
                if isinstance(e, req_exc.RequestException) and hasattr(e, "response"):
                    if e.response is not None and e.response.status_code not in retry_config.retry_status:
                        raise ApiError(f"Non-retryable error: {e.response.status_code}") from e
                if attempt < retry_config.max_retries - 1:
                    delay = self._calculate_delay(attempt, retry_config)
                    self.logger.info(f"JSON parsing failed, {delay:.1f} seconds later retry...")
                    time.sleep(delay)

        raise ApiError(f"Retry {retry_config.max_retries} times but still failed") from last_error

    def _execute_request(self, api: Union[OpenAIChatApi], data: str) -> Dict[str, Any]:
        """Issue the request and coerce the SDK response into a plain dict."""
        response = api.call_data_eval(data)

        try:
            if hasattr(response, "model_dump"):
                # Pydantic v2 models (openai>=1.x responses).
                return response.model_dump()
            if isinstance(response, dict):
                return response
            if hasattr(response, "__dict__"):
                # Last resort: round-trip through JSON, stringifying unknowns.
                return json.loads(json.dumps(response.__dict__, default=str))
        except Exception as e:
            raise ResponseParseError(f"Unable to parse OpenAI returned object: {type(response)} - {e}") from e

        raise ResponseParseError(f"Unknown response type: {type(response)}")

    def _extract_cost(self, payload: Dict[str, Any]) -> float:
        """Read an optional cost field (micro-units) from the payload.

        NOTE(review): defined but never called by call_data_eval_with_retry,
        which hard-codes cost = 0.0 — confirm whether it should be wired in.
        """
        try:
            return float(payload.get("cost_info", {}).get("cost", 0)) / 1e6
        except (AttributeError, KeyError):
            return 0.0

    def _validate_result(self, result: Union[Dict[str, Any], int]) -> None:
        """Ensure the parsed answer carries the expected rewrite fields.

        NOTE(review): the check accepts int OR str for both fields, so e.g. a
        string "duration" passes here and only fails later at float() time.
        """
        if isinstance(result, int):
            return
        elif isinstance(result, dict):
            required_fields = ["duration", "short_caption"]
            for field in required_fields:
                if not isinstance(result.get(field), (int, str)):
                    raise ResponseParseError(f"LLM returned invalid format: {field}")
        else:
            raise ResponseParseError(f"Unsupported answer type: {type(result)}")

    def _calculate_delay(self, attempt: int, config: RetryConfig) -> float:
        """Exponential backoff with jitter, capped at config.max_delay."""
        delay = config.base_delay * (2**attempt) * (0.5 + random.random())
        return min(delay, config.max_delay)

    def _parse_answer(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Route an OpenAI-style payload to the choices parser."""
        if isinstance(payload, dict) and "choices" in payload:
            return self._parse_from_choices_field(payload)

        raise ResponseParseError("Unknown response format: expected choices")

    def _parse_from_choices_field(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the first choice's text and parse it as JSON."""
        choices = payload.get("choices") or []
        if not choices:
            raise ResponseParseError("OpenAI returned empty")

        content = self._extract_content_from_choice(choices[0])

        if not isinstance(content, str) or not content.strip():
            raise ResponseParseError("OpenAI returned no valid content")

        return self._parse_json_content(content)

    def _extract_content_from_choice(self, choice: Any) -> Optional[str]:
        """Pull the text out of a choice, tolerating dict or object shapes."""
        content = None

        if isinstance(choice, dict):
            # Try message content first
            msg = choice.get("message") or {}
            content = msg.get("content")
            # Fallback to delta content or text
            if content is None:
                delta = choice.get("delta") or {}
                content = delta.get("content", choice.get("text"))
        else:
            # Handle object-like choice (e.g. Pydantic model)
            msg = getattr(choice, "message", None)
            if msg is not None:
                content = getattr(msg, "content", None)

            if content is None:
                delta = getattr(choice, "delta", None)
                if delta is not None:
                    content = getattr(delta, "content", None)

            if content is None:
                content = getattr(choice, "text", None)

        return content

    def _parse_json_content(self, content: str) -> Dict[str, Any]:
        """Strip code fences and decode the JSON body."""
        cleaned = self._cleanup_fenced_json(content)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as e:
            self.logger.warning(f"JSON parsing failed, original content: {cleaned[:500]}...")
            raise ResponseParseError(f"JSON parsing failed: {e}") from e

    def _cleanup_fenced_json(self, text: str) -> str:
        """Remove ```json fences and surrounding prose around the outermost {...}."""
        text = text.strip()
        if text.startswith("```"):
            text = re.sub(r"^```(?:json)?\s*", "", text)
            text = re.sub(r"\s*```$", "", text)
        if not text.lstrip().startswith("{") and "{" in text and "}" in text:
            start = text.find("{")
            end = text.rfind("}")
            if 0 <= start < end:
                text = text[start : end + 1]
        return text
252
+
253
+
254
class PromptRewriter:
    """Rewrites motion captions and infers durations via an LLM backend."""

    def __init__(
        self,
        host: Optional[str] = None,
        parser: Optional[ResponseParser] = None,
        backend: Literal["our_rewriter"] = "our_rewriter",
    ):
        # NOTE(review): host defaults to None, which is forwarded to
        # OpenAI(base_url=None) — the SDK then uses its own default endpoint;
        # confirm callers always pass a real host.
        self.parser = parser or ResponseParser()
        self.logger = logging.getLogger(__name__)
        self.backend = backend.lower()

        if self.backend == "our_rewriter":
            self.api = OpenAIChatApi(
                ApiConfig(
                    host=host,
                    user="",
                    apikey="EMPTY",
                    model="Qwen3-30B-A3B-SFT",
                    api_version="",
                )
            )
        else:
            raise ValueError(f"Invalid backend: {self.backend}")

    def rewrite_prompt_and_infer_time(
        self,
        text: str,
        prompt_format: str = REWRITE_AND_INFER_TIME_PROMPT_FORMAT,
        retry_config: Optional[RetryConfig] = None,
    ) -> Tuple[float, str]:
        """Return (duration_seconds, rewritten_caption) for a raw action text.

        The LLM reports duration in frames at 30 fps; it is converted to
        seconds here. Re-raises any failure after the retry policy gives up.
        """
        self.logger.info("Start rewriting prompt...")
        try:
            result, cost, elapsed = self.parser.call_data_eval_with_retry(
                self.api, prompt_format.format(text), retry_config
            )
            self.logger.info(f"Rewriting completed - cost: {cost:.6f}, time: {elapsed:.2f}s")
            # frames @ 30 fps -> seconds, rounded to 2 decimals.
            return round(float(result["duration"]) / 30.0, 2), result["short_caption"]

        except Exception as e:
            self.logger.error(f"Prompt rewriting failed: {e}")
            raise
295
+
296
+
297
+ if __name__ == "__main__":
298
+ # python -m hymotion.prompt_engineering.prompt_rewrite
299
+
300
+ logging.basicConfig(level=logging.INFO)
301
+ text = "person jumps after they runs"
302
+ prompt_rewriter = PromptRewriter(backend="our_rewriter")
303
+ result = prompt_rewriter.rewrite_prompt_and_infer_time(text)
304
+ print(result)
hymotion/utils/configs.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import copy
3
+ import os.path as osp
4
+ import platform
5
+ import re
6
+ import shutil
7
+ import sys
8
+ import tempfile
9
+ import types
10
+ import uuid
11
+ from importlib import import_module
12
+ from pathlib import Path
13
+ from typing import Any, Dict, Iterator, NoReturn, Optional, Union
14
+ import yaml
15
+
16
+ from .misc import import_modules_from_strings
17
+ from .path import check_file_exist
18
+
19
# Key in a config dict that names the base config file(s) to inherit from
# (resolved recursively in Config._file2dict).
BASE_KEY = "_base_"
# Sentinel key: when truthy in a child dict, Config._merge_a_into_b replaces
# the base value instead of merging into it.
DELETE_KEY = "_delete_"
# Keys rejected by Config.__init__ -- presumably to avoid clashing with
# Config-level accessors; TODO confirm ("pretty_text" is not defined here).
RESERVED_KEYS = ["filename", "text", "pretty_text"]
22
+
23
+
24
class Config:
    """Config container supporting attribute access, ``.py``-file loading with
    ``_base_`` inheritance, predefined-variable substitution, and YAML
    round-tripping.

    Entries are stored in a ``ConfigDict`` so they can be read either as
    attributes (``cfg.key``) or items (``cfg["key"]``).
    """

    def __init__(
        self,
        cfg_dict: Optional[dict] = None,
        cfg_text: Optional[str] = None,
        filename: Optional[str] = None,
    ) -> None:
        """Build a Config from an in-memory dict.

        Args:
            cfg_dict: config entries; must not contain ``RESERVED_KEYS``.
            cfg_text: original source text (read from ``filename`` if omitted).
            filename: path this config was loaded from, if any.

        Raises:
            TypeError: if ``cfg_dict`` is not a dict.
            KeyError: if a reserved key is present.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f"{key} is reserved for config file")

        if isinstance(filename, Path):
            filename = str(filename)

        # Bypass our own __setattr__ (which routes writes into _cfg_dict)
        # to set the real instance attributes.
        super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
        super(Config, self).__setattr__("_filename", filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, "r") as f:
                text = f.read()
        else:
            text = ""
        super(Config, self).__setattr__("_text", text)

    @staticmethod
    def fromfile(
        filename: str,
        use_predefined_variables: bool = True,
        import_custom_modules: bool = True,
    ) -> "Config":
        """Load a Config from a ``.py`` file, resolving ``_base_`` inheritance
        and optionally importing modules listed under ``custom_imports``."""
        if isinstance(filename, Path):
            filename = str(filename)
        cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables)
        if import_custom_modules and cfg_dict.get("custom_imports", None):
            import_modules_from_strings(**cfg_dict["custom_imports"])
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def _file2dict(filename: str, use_predefined_variables: bool = True) -> tuple[dict, str]:
        """Parse a ``.py`` config file into ``(cfg_dict, cfg_text)``.

        Handles predefined-variable substitution, ``{{ _base_.x }}``
        placeholders, and recursive ``_base_`` inheritance/merging.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in [".py"]:
            raise IOError("Only py type are supported now!")

        cfg_dict = {}

        with tempfile.TemporaryDirectory() as temp_config_dir:
            temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=fileExtname)
            if platform.system() == "Windows":
                # Windows cannot reopen a still-open NamedTemporaryFile by name.
                temp_config_file.close()
            temp_config_name = osp.basename(temp_config_file.name)
            # Substitute predefined variables
            if use_predefined_variables:
                Config._substitute_predefined_vars(filename, temp_config_file.name)
            else:
                shutil.copyfile(filename, temp_config_file.name)
            # Substitute base variables from placeholders to strings
            base_var_dict = Config._pre_substitute_base_vars(temp_config_file.name, temp_config_file.name)

            if filename.endswith(".py"):
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                Config._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                # Keep only plain config values; drop dunders, modules, functions.
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith("__")
                    and not isinstance(value, types.ModuleType)
                    and not isinstance(value, types.FunctionType)
                }
                # delete imported module so repeated loads re-execute the file
                del sys.modules[temp_module_name]

            # close temp file
            temp_config_file.close()

        cfg_text = filename + "\n"
        with open(filename, "r", encoding="utf-8") as f:
            # Setting encoding explicitly to resolve coding issue on windows
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                duplicate_keys = base_cfg_dict.keys() & c.keys()
                if len(duplicate_keys) > 0:
                    raise KeyError("Duplicate key is not allowed among bases. " f"Duplicate keys: {duplicate_keys}")
                base_cfg_dict.update(c)

            # Substitute base variables from strings to their actual values
            cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, base_cfg_dict)
            assert isinstance(cfg_dict, dict)

            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = "\n".join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _validate_py_syntax(filename: str) -> None:
        """Raise SyntaxError early if the config file is not valid Python."""
        with open(filename, "r", encoding="utf-8") as f:
            # Setting encoding explicitly to resolve coding issue on windows
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as e:
            raise SyntaxError("There are syntax errors in config " f"file (unknown): {e}")

    @staticmethod
    def _pre_substitute_base_vars(filename: str, temp_config_name: str) -> dict:
        """Substitute base variable placeholders to strings, so that parsing
        would work.

        Each ``{{ _base_.path }}`` is replaced by a unique random string;
        returns the mapping {random string: dotted path} for later resolution.
        """
        with open(filename, "r", encoding="utf-8") as f:
            config_file = f.read()
        base_var_dict = {}
        regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}"
        base_vars = set(re.findall(regexp, config_file))
        for base_var in base_vars:
            randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}"
            base_var_dict[randstr] = base_var
            regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}"
            config_file = re.sub(regexp, f'"{randstr}"', config_file)
        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
            tmp_config_file.write(config_file)
        return base_var_dict

    @staticmethod
    def _substitute_base_vars(
        cfg: Union[dict, list, tuple, str],
        base_var_dict: dict,
        base_cfg: dict,
    ) -> Union[dict, list, tuple, str]:
        """Substitute variable strings to their actual values (recursively)."""
        cfg = copy.deepcopy(cfg)

        if isinstance(cfg, dict):
            for k, v in cfg.items():
                if isinstance(v, str) and v in base_var_dict:
                    # Walk the dotted path inside the merged base config.
                    new_v = base_cfg
                    for new_k in base_var_dict[v].split("."):
                        new_v = new_v[new_k]
                    cfg[k] = new_v
                elif isinstance(v, (list, tuple, dict)):
                    cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg)
        elif isinstance(cfg, tuple):
            cfg = tuple(Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg)
        elif isinstance(cfg, list):
            cfg = [Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg]
        elif isinstance(cfg, str) and cfg in base_var_dict:
            new_v = base_cfg
            for new_k in base_var_dict[cfg].split("."):
                new_v = new_v[new_k]
            cfg = new_v

        return cfg

    @staticmethod
    def _substitute_predefined_vars(filename: str, temp_config_name: str) -> None:
        """Expand ``{{ fileDirname }}``-style templates into the temp copy."""
        file_dirname = osp.dirname(filename)
        file_basename = osp.basename(filename)
        file_basename_no_extension = osp.splitext(file_basename)[0]
        file_extname = osp.splitext(filename)[1]
        support_templates = dict(
            fileDirname=file_dirname,
            fileBasename=file_basename,
            fileBasenameNoExtension=file_basename_no_extension,
            fileExtname=file_extname,
        )
        with open(filename, "r", encoding="utf-8") as f:
            config_file = f.read()
        for key, value in support_templates.items():
            regexp = r"\{\{\s*" + str(key) + r"\s*\}\}"
            # Normalize Windows path separators so the substituted string is
            # a valid Python literal.
            value = value.replace("\\", "/")
            config_file = re.sub(regexp, value, config_file)
        with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file:
            tmp_config_file.write(config_file)

    @staticmethod
    def _merge_a_into_b(a: dict, b: dict, allow_list_keys: bool = False) -> dict:
        """Recursively merge dict ``a`` into a copy of ``b`` (``a`` wins).

        A child dict containing a truthy ``DELETE_KEY`` replaces the base
        value instead of merging. With ``allow_list_keys``, digit keys index
        into base lists.
        """
        b = b.copy()
        for k, v in a.items():
            if allow_list_keys and k.isdigit() and isinstance(b, list):
                k = int(k)
                if len(b) <= k:
                    raise KeyError(f"Index {k} exceeds the length of list {b}")
                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
            elif isinstance(v, dict):
                if k in b and not v.pop(DELETE_KEY, False):
                    allowed_types = (dict, list) if allow_list_keys else dict
                    if not isinstance(b[k], allowed_types):
                        raise TypeError(
                            f"{k}={v} in child config cannot inherit from "
                            f"base because {k} is a dict in the child config "
                            f"but is of type {type(b[k])} in base config. "
                            f"You may set `{DELETE_KEY}=True` to ignore the "
                            f"base config."
                        )
                    b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
                else:
                    b[k] = ConfigDict(v)
            else:
                b[k] = v
        return b

    def to_dict(self) -> Any:
        """Recursively convert the config to plain builtins (dict/list)."""

        def convert_configdict(obj):
            if isinstance(obj, ConfigDict):
                return {k: convert_configdict(v) for k, v in obj.items()}
            elif isinstance(obj, dict):
                return {k: convert_configdict(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                return [convert_configdict(item) for item in obj]
            else:
                return obj

        return convert_configdict(self._cfg_dict)

    @classmethod
    def from_dict(cls, cfg_dict: dict, filename: Optional[str] = None) -> "Config":
        """Alternate constructor from a plain dict."""
        return cls(cfg_dict=cfg_dict, filename=filename)

    def save_yaml(self, filename: str) -> None:
        """Dump the config to a YAML file."""
        with open(filename, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f, default_flow_style=False, indent=2)

    @classmethod
    def load_yaml(cls, filename: str) -> "Config":
        """Load a Config from a YAML file."""
        with open(filename, "r", encoding="utf-8") as f:
            cfg_dict = yaml.safe_load(f)
        return cls.from_dict(cfg_dict, filename=filename)

    def __repr__(self) -> str:
        # BUGFIX: use the real attribute. `self.filename` would be routed
        # through __getattr__ into the config dict (there is no `filename`
        # attribute or property) and raise AttributeError.
        return f"Config (path: {self._filename}): {self._cfg_dict.__repr__()}"

    def __len__(self) -> int:
        return len(self._cfg_dict)

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails; delegate to the
        # ConfigDict so config keys read as attributes.
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name: str) -> Any:
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name: str, value: Any) -> None:
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self) -> Iterator[Any]:
        return iter(self._cfg_dict)

    def __getstate__(self) -> tuple[dict, str, str]:
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state: tuple) -> None:
        # BUGFIX: __getstate__ returns a 3-tuple, which the default pickle
        # machinery cannot restore; rebuild the instance attributes here.
        _cfg_dict, _filename, _text = state
        super(Config, self).__setattr__("_cfg_dict", _cfg_dict)
        super(Config, self).__setattr__("_filename", _filename)
        super(Config, self).__setattr__("_text", _text)

    def __copy__(self) -> "Config":
        cls = self.__class__
        other = cls.__new__(cls)
        other.__dict__.update(self.__dict__)

        return other

    def __deepcopy__(self, memo: dict) -> "Config":
        cls = self.__class__
        other = cls.__new__(cls)
        memo[id(self)] = other

        for key, value in self.__dict__.items():
            super(Config, other).__setattr__(key, copy.deepcopy(value, memo))

        return other
+ return other
321
+
322
+
323
class ConfigDict(Dict):
    """Dict subclass exposing keys as attributes.

    Missing keys raise ``KeyError`` on item access and ``AttributeError``
    on attribute access, so both access styles behave like a plain dict.
    """

    def __missing__(self, name: str) -> NoReturn:
        raise KeyError(name)

    def __getattr__(self, name: str) -> Any:
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
        return value

    def to_dict(self) -> Any:
        """Recursively convert to plain builtins (dicts and lists)."""

        def _plain(node):
            if isinstance(node, dict):  # covers ConfigDict as well
                return {key: _plain(val) for key, val in node.items()}
            if isinstance(node, (list, tuple)):
                return [_plain(item) for item in node]
            return node

        return _plain(dict(self))
hymotion/utils/geometry.py ADDED
@@ -0,0 +1,856 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import Tensor
7
+
8
+
9
def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
    """Map the 6D rotation representation of Zhou et al. to rotation
    matrices via Gram--Schmidt orthogonalization (Section B of the paper).

    Args:
        d6: (*, 6) tensor of 6D rotation representations.

    Returns:
        (*, 3, 3) batch of rotation matrices.

    Reference: Zhou et al., "On the Continuity of Rotation Representations
    in Neural Networks", CVPR 2019. http://arxiv.org/abs/1812.07035
    """
    first, second = d6[..., :3], d6[..., 3:]
    row0 = F.normalize(first, dim=-1)
    # Remove the component of `second` along row0, then normalize.
    projection = (row0 * second).sum(-1, keepdim=True)
    row1 = F.normalize(second - projection * row0, dim=-1)
    row2 = torch.cross(row0, row1, dim=-1)
    return torch.stack((row0, row1, row2), dim=-2)
31
+
32
+
33
def matrix_to_rotation_6d(matrix: Tensor) -> Tensor:
    """Convert rotation matrices to the 6D representation of Zhou et al. by
    keeping only the first two rows (the mapping is not unique).

    Args:
        matrix: (*, 3, 3) batch of rotation matrices.

    Returns:
        (*, 6) tensor: the first two rows, flattened.

    Reference: Zhou et al., "On the Continuity of Rotation Representations
    in Neural Networks", CVPR 2019. http://arxiv.org/abs/1812.07035
    """
    leading_shape = matrix.size()[:-2]
    top_two_rows = matrix[..., :2, :].clone()
    return top_two_rows.reshape(leading_shape + (6,))
50
+
51
+
52
def standardize_quaternion(quaternions: Tensor) -> Tensor:
    """Flip quaternions (q -> -q) wherever the real part is negative, so
    every output quaternion has a non-negative real part (q and -q encode
    the same rotation).

    Args:
        quaternions: (..., 4) quaternions, real part first.

    Returns:
        (..., 4) standardized quaternions.
    """
    has_negative_real = quaternions[..., 0:1] < 0
    return torch.where(has_negative_real, -quaternions, quaternions)
65
+
66
+
67
+ def _sqrt_positive_part(x: Tensor) -> Tensor:
68
+ """Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0."""
69
+ ret = torch.zeros_like(x)
70
+ positive_mask = x > 0
71
+ if torch.is_grad_enabled():
72
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
73
+ else:
74
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
75
+ return ret
76
+
77
+
78
def matrix_to_quaternion(matrix: Tensor) -> Tensor:
    """Convert rotations given as rotation matrices to quaternions.

    Branch-free, numerically stable variant: all four quaternion candidates
    are computed and the best-conditioned one (largest |component|) is
    selected per element.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: if the trailing dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    # 2*|q_r|, 2*|q_i|, 2*|q_j|, 2*|q_k| (up to scaling) from the diagonal terms.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
    return standardize_quaternion(out)
133
+
134
+
135
def quaternion_to_axis_angle(quaternions: Tensor) -> Tensor:
    """Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    # |imaginary part| = sin(theta/2) for a unit quaternion.
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    # Use a Taylor expansion for tiny angles to avoid dividing by ~0.
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
    return quaternions[..., 1:] / sin_half_angles_over_angles
159
+
160
+
161
def matrix_to_axis_angle(matrix: Tensor) -> Tensor:
    """Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3).

    The output direction is the rotation axis and its magnitude is the
    anticlockwise rotation angle in radians.
    """
    as_quaternion = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(as_quaternion)
174
+
175
+
176
def quaternion_to_matrix(quaternions: Tensor) -> Tensor:
    """Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: (..., 4) quaternions with real part first.

    Returns:
        (..., 3, 3) rotation matrices.
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    # 2 / |q|^2 folds normalization of non-unit quaternions into the scale.
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    flat = torch.stack(
        (
            1 - scale * (y * y + z * z),
            scale * (x * y - z * w),
            scale * (x * z + y * w),
            scale * (x * y + z * w),
            1 - scale * (x * x + z * z),
            scale * (y * z - x * w),
            scale * (x * z - y * w),
            scale * (y * z + x * w),
            1 - scale * (x * x + y * y),
        ),
        -1,
    )
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
205
+
206
+
207
def axis_angle_to_quaternion(axis_angle: Tensor) -> Tensor:
    """Convert rotations given as axis/angle to quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small_angles = angles.abs() < eps
    # Use a Taylor expansion for tiny angles to avoid dividing by ~0.
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
    quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1)
    return quaternions
230
+
231
+
232
def axis_angle_to_matrix(axis_angle: Tensor) -> Tensor:
    """Convert axis-angle vectors (..., 3) to rotation matrices (..., 3, 3).

    The input direction is the rotation axis and its magnitude is the
    anticlockwise rotation angle in radians.
    """
    as_quaternion = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(as_quaternion)
245
+
246
+
247
def get_T_w2c_from_wcparams(
    global_orient_w: Tensor, transl_w: Tensor, global_orient_c: Tensor, transl_c: Tensor, offset: Tensor
) -> Tensor:
    """Recover per-frame world-to-camera rigid transforms from root
    parameters expressed in both world (w) and camera (c) coordinates.

    Args:
        global_orient_w: Tensor, (F, 3) axis-angle root orientation in world coords.
        transl_w: Tensor, (F, 3) root translation in world coords.
        global_orient_c: Tensor, (F, 3) axis-angle root orientation in camera coords.
        transl_c: Tensor, (F, 3) root translation in camera coords.
        offset: Tensor, (*, 3) translation offset applied on both sides of the
            transform -- presumably a body-root offset; TODO confirm with callers.
    Returns:
        T_w2c: Tensor, (F, 4, 4) homogeneous transforms.
    """
    assert global_orient_w.shape == transl_w.shape and len(global_orient_w.shape) == 2
    assert global_orient_c.shape == transl_c.shape and len(global_orient_c.shape) == 2

    R_w = axis_angle_to_matrix(global_orient_w)  # (F, 3, 3)
    t_w = transl_w  # (F, 3)
    R_c = axis_angle_to_matrix(global_orient_c)  # (F, 3, 3)
    t_c = transl_c  # (F, 3)

    # Rotation mapping world orientation to camera orientation: R_c = R_w2c @ R_w.
    R_w2c = R_c @ R_w.transpose(-1, -2)  # (F, 3, 3)
    # Translation from: R_w2c @ (t_w + offset) + t_w2c = t_c + offset.
    t_w2c = t_c + offset - torch.einsum("fij,fj->fi", R_w2c, t_w + offset)  # (F, 3)
    T_w2c = torch.eye(4, device=global_orient_w.device).repeat(R_w.size(0), 1, 1)  # (F, 4, 4)
    T_w2c[..., :3, :3] = R_w2c  # (F, 3, 3)
    T_w2c[..., :3, 3] = t_w2c  # (F, 3)
    return T_w2c
274
+
275
+
276
def get_R_c2gv(R_w2c, axis_gravity_in_w=[0, 0, -1]):
    """Compute the rotation from camera coordinates to a gravity-view (gv)
    frame whose y-axis is aligned with gravity as seen by the camera.

    The gv x-axis is chosen orthogonal to both gravity and the camera
    z-axis; when gravity is (nearly) parallel to the camera z-axis, the
    camera x-axis is used as a fallback.

    Args:
        R_w2c: (*, 3, 3) world-to-camera rotation matrices.
        axis_gravity_in_w: gravity direction in world coords; a length-3
            list/tuple or a (3,) tensor.
    Returns:
        R_c2gv: (*, 3, 3)
    """
    # BUGFIX: create helper tensors on the input's device/dtype; hard-coded
    # CPU float32 tensors broke CUDA (and non-float32) inputs.
    device, dtype = R_w2c.device, R_w2c.dtype
    if isinstance(axis_gravity_in_w, (list, tuple)):
        axis_gravity_in_w = torch.tensor(axis_gravity_in_w, device=device, dtype=dtype)
    else:
        axis_gravity_in_w = axis_gravity_in_w.to(device=device, dtype=dtype)
    axis_z_in_c = torch.tensor([0.0, 0.0, 1.0], device=device, dtype=dtype)

    # get gv-coord axes in c-coord
    axis_y_of_gv = R_w2c @ axis_gravity_in_w  # (*, 3)
    axis_x_of_gv = axis_y_of_gv.cross(axis_z_in_c.expand_as(axis_y_of_gv), dim=-1)
    # normalize; guard the degenerate case (gravity parallel to camera z-axis)
    axis_x_of_gv_norm = axis_x_of_gv.norm(dim=-1, keepdim=True)
    axis_x_of_gv = axis_x_of_gv / (axis_x_of_gv_norm + 1e-5)
    degenerate = axis_x_of_gv_norm.squeeze(-1) < 1e-5
    # use cam x-axis as axis_x_of_gv in the degenerate case
    axis_x_of_gv[degenerate] = torch.tensor([1.0, 0.0, 0.0], device=device, dtype=dtype)
    axis_z_of_gv = axis_x_of_gv.cross(axis_y_of_gv, dim=-1)

    R_gv2c = torch.stack([axis_x_of_gv, axis_y_of_gv, axis_z_of_gv], dim=-1)  # (*, 3, 3)
    R_c2gv = R_gv2c.transpose(-1, -2)  # (*, 3, 3)
    return R_c2gv
299
+
300
+
301
def get_c_rootparam(global_orient: Tensor, transl: Tensor, T_w2c: Tensor, offset: Tensor) -> Tuple[Tensor, Tensor]:
    """Transform world-frame root parameters into camera coordinates using
    a (possibly shared) world-to-camera transform.

    Args:
        global_orient: Tensor, (F, 3) axis-angle root orientation in world coords.
        transl: Tensor, (F, 3) root translation in world coords.
        T_w2c: Tensor, (*, 4, 4) -- either one transform for all frames
            (shape (4, 4)) or one per frame (shape (F, 4, 4)).
        offset: Tensor, (3,) offset applied on both sides of the transform
            -- presumably a body-root offset; TODO confirm with callers.
    Returns:
        R_c: Tensor, (F, 3) axis-angle orientation in camera coords.
        t_c: Tensor, (F, 3) translation in camera coords.
    """
    assert global_orient.shape == transl.shape and len(global_orient.shape) == 2
    R_w = axis_angle_to_matrix(global_orient)  # (F, 3, 3)
    t_w = transl  # (F, 3)

    R_w2c = T_w2c[..., :3, :3]  # (*, 3, 3)
    t_w2c = T_w2c[..., :3, 3]  # (*, 3)
    if len(R_w2c.shape) == 2:
        # Single transform shared by all frames: broadcast to (F, ...).
        R_w2c = R_w2c[None].expand(R_w.size(0), -1, -1)  # (F, 3, 3)
        t_w2c = t_w2c[None].expand(t_w.size(0), -1)

    R_c = matrix_to_axis_angle(R_w2c @ R_w)  # (F, 3)
    t_c = torch.einsum("fij,fj->fi", R_w2c, t_w + offset) + t_w2c - offset  # (F, 3)
    return R_c, t_c
325
+
326
+
327
def compute_cam_angvel(R_w2c, padding_last=True):
    """Camera angular "velocity": per-frame relative rotations in 6D form.

    Args:
        R_w2c: (F, 3, 3) world-to-camera rotations per frame.
        padding_last: must be True; the last relative rotation is repeated
            so the output keeps F entries.

    Returns:
        (F, 6) float tensor of 6D rotation deltas between consecutive frames.
    """
    # R @ R0 = R1, so R = R1 @ R0^T
    cam_angvel = matrix_to_rotation_6d(R_w2c[1:] @ R_w2c[:-1].transpose(-1, -2))  # (F-1, 6)
    # cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]])) * FPS
    assert padding_last
    # Repeat the last delta so output length matches the input frame count.
    cam_angvel = torch.cat([cam_angvel, cam_angvel[-1:]], dim=0)  # (F, 6)
    return cam_angvel.float()
337
+
338
+
339
def rot6d_to_rotation_matrix(rot6d):
    """Convert 6D rotation representation to 3x3 rotation matrices.

    Note the layout differs from ``rotation_6d_to_matrix``: here the 6 values
    are read as a (3, 2) column pair, and the result stacks the orthonormal
    basis as matrix *columns*.

    Based on Zhou et al., "On the Continuity of Rotation Representations in
    Neural Networks", CVPR 2019.

    Args:
        rot6d: (..., 6) tensor of 6D rotation representations.
    Returns:
        (..., 3, 3) tensor of rotation matrices.
    """
    cols = rot6d.view(*rot6d.shape[:-1], 3, 2)
    raw_x = cols[..., 0]
    raw_y = cols[..., 1]
    basis_x = F.normalize(raw_x, dim=-1)
    # Gram-Schmidt: remove the basis_x component from raw_y, then normalize.
    dot = torch.einsum("...i,...i->...", basis_x, raw_y).unsqueeze(-1)
    basis_y = F.normalize(raw_y - dot * basis_x, dim=-1)
    basis_z = torch.cross(basis_x, basis_y, dim=-1)
    return torch.stack((basis_x, basis_y, basis_z), dim=-1)
356
+
357
+
358
def rotation_matrix_to_rot6d(rotation_matrix):
    """Convert 3x3 rotation matrices to the 6D representation that keeps the
    first two *columns* (inverse of ``rot6d_to_rotation_matrix``; a different
    layout from ``matrix_to_rotation_6d``, which keeps rows).

    Args:
        rotation_matrix: (..., 3, 3) rotation matrices.
    Returns:
        (..., 6) tensor of 6D rotation representations.
    """
    first_two_cols = rotation_matrix[..., :2]  # (..., 3, 2)
    return first_two_cols.reshape(*rotation_matrix.shape[:-2], 6)
370
+
371
+
372
def quaternion_to_rotation_matrix(quaternion):
    """Convert (w, x, y, z) quaternions to rotation matrices.

    The input is normalized first, so non-unit quaternions are accepted.

    Args:
        quaternion: (..., 4) tensor in (w, x, y, z) order.
    Returns:
        (..., 3, 3) rotation matrices.
    """
    unit = quaternion / quaternion.norm(p=2, dim=-1, keepdim=True)
    w, x, y, z = unit[..., 0], unit[..., 1], unit[..., 2], unit[..., 3]

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    flat = torch.stack(
        [
            w2 + x2 - y2 - z2,
            2 * xy - 2 * wz,
            2 * wy + 2 * xz,
            2 * wz + 2 * xy,
            w2 - x2 + y2 - z2,
            2 * yz - 2 * wx,
            2 * xz - 2 * wy,
            2 * wx + 2 * yz,
            w2 - x2 - y2 + z2,
        ],
        dim=-1,
    )
    return flat.view(*quaternion.shape[:-1], 3, 3)
405
+
406
+
407
def quaternion_to_angle_axis(quaternion: Tensor) -> Tensor:
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert quaternion vector to angle axis of rotation.

    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    Args:
        quaternion (Tensor): tensor with quaternions, real part first.

    Return:
        Tensor: tensor with angle axis of rotation.

    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`

    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion)))

    if not quaternion.shape[-1] == 4:
        raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape))
    # unpack input and compute conversion
    q1: torch.Tensor = quaternion[..., 1]
    q2: torch.Tensor = quaternion[..., 2]
    q3: torch.Tensor = quaternion[..., 3]
    # |imaginary part|^2 = sin^2(theta/2)
    sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3

    sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
    cos_theta: torch.Tensor = quaternion[..., 0]
    # Flip signs when cos_theta < 0 to keep the angle in [0, pi].
    two_theta: torch.Tensor = 2.0 * torch.where(
        cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta)
    )

    # k scales the imaginary part into an axis-angle vector; k_neg = 2 is the
    # small-angle limit of two_theta / sin_theta, avoiding division by zero.
    k_pos: torch.Tensor = two_theta / sin_theta
    k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
    k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)

    angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
    angle_axis[..., 0] += q1 * k
    angle_axis[..., 1] += q2 * k
    angle_axis[..., 2] += q3 * k
    return angle_axis
455
+
456
+
457
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert 3x4 rotation matrix to 4d quaternion vector

    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Args:
        rotation_matrix (Tensor): the rotation matrix to convert.
        eps (float): threshold used when comparing the diagonal entries to
            pick the numerically stable branch below.

    Return:
        Tensor: the rotation in quaternion, scalar-first (w, x, y, z).

    Shape:
        - Input: :math:`(N, 3, 4)` (an :math:`(N, 3, 3)` input is padded
          with a [0, 0, 1] column first)
        - Output: :math:`(N, 4)`

    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix)))

    if len(rotation_matrix.shape) > 3:
        raise ValueError("Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape))
    if not rotation_matrix.shape[-2:] == (3, 4):
        # Pad 3x3 matrices to the 3x4 layout the branches below expect.
        hom = (
            torch.tensor([0, 0, 1], dtype=rotation_matrix.dtype, device=rotation_matrix.device)
            .reshape(1, 3, 1)
            .expand(rotation_matrix.shape[0], -1, -1)
        )
        rotation_matrix = torch.cat([rotation_matrix, hom], dim=-1)

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    # Branch selection masks based on the matrix diagonal (Shepperd's method):
    # each branch divides by a different trace-like term; pick the largest.
    mask_d2 = rmat_t[:, 2, 2] < eps

    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    # Candidate quaternions, one per branch (computed for every element,
    # then blended with the masks below).
    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack(
        [rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2]],
        -1,
    )
    t0_rep = t0.repeat(4, 1).t()

    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack(
        [rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]],
        -1,
    )
    t1_rep = t1.repeat(4, 1).t()

    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack(
        [rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2],
        -1,
    )
    t2_rep = t2.repeat(4, 1).t()

    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack(
        [t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0]],
        -1,
    )
    t3_rep = t3.repeat(4, 1).t()

    # Exactly one of the four masks is True per element; `*` on bool tensors
    # acts as logical AND.
    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * ~mask_d0_d1
    mask_c2 = ~mask_d2 * mask_d0_nd1
    mask_c3 = ~mask_d2 * ~mask_d0_nd1
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    # Blend the candidates and apply the per-branch normalization.
    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 + t3_rep * mask_c3)
    q *= 0.5
    return q
541
+
542
+
543
def rotation_matrix_to_angle_axis(rotation_matrix):
    """Convert rotation matrices to Rodrigues (axis-angle) vectors.

    Borrowed from https://github.com/kornia/kornia.

    Args:
        rotation_matrix (Tensor): matrices of shape (..., 3, 3) or (..., 3, 4).

    Returns:
        Tensor: axis-angle vectors of shape (..., 3).
    """
    lead_shape = rotation_matrix.shape[:-2]
    mats = rotation_matrix.reshape(-1, *rotation_matrix.shape[-2:])
    if mats.shape[1:] == (3, 3):
        # Pad each 3x3 matrix with a [0, 0, 1] column to the 3x4 layout
        # expected by rotation_matrix_to_quaternion.
        pad = (
            torch.tensor([0, 0, 1], dtype=rotation_matrix.dtype, device=rotation_matrix.device)
            .reshape(1, 3, 1)
            .expand(mats.shape[0], -1, -1)
        )
        mats = torch.cat([mats, pad], dim=-1)

    quat = rotation_matrix_to_quaternion(mats)
    aa = quaternion_to_angle_axis(quat)
    # Degenerate matrices can yield NaNs through the quaternion path.
    aa[torch.isnan(aa)] = 0.0
    return aa.reshape(*lead_shape, 3)
579
+
580
+
581
def quat_to_rotmat(quat):
    """Convert quaternion coefficients to rotation matrices.

    Args:
        quat: (B, 4) quaternions ordered (w, x, y, z); need not be unit norm.

    Returns:
        (B, 3, 3) rotation matrices.
    """
    unit = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = unit[:, 0], unit[:, 1], unit[:, 2], unit[:, 3]

    ww, xx, yy, zz = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    # Standard quaternion-to-matrix expansion, row-major.
    row0 = [ww + xx - yy - zz, 2 * xy - 2 * wz, 2 * wy + 2 * xz]
    row1 = [2 * wz + 2 * xy, ww - xx + yy - zz, 2 * yz - 2 * wx]
    row2 = [2 * xz - 2 * wy, 2 * wx + 2 * yz, ww - xx - yy + zz]
    return torch.stack(row0 + row1 + row2, dim=1).view(quat.size(0), 3, 3)
614
+
615
+
616
def angle_axis_to_rotation_matrix(theta):
    """Convert axis-angle vectors to rotation matrices.

    Args:
        theta: (..., 3) axis-angle vectors (direction = axis, norm = angle).

    Returns:
        (..., 3, 3) rotation matrices.
    """
    lead_shape = theta.shape[:-1]
    aa = theta.reshape(-1, 3)
    # NOTE: the epsilon is added to the vector before taking the norm,
    # matching the reference implementation (avoids division by zero).
    angle = torch.norm(aa + 1e-8, p=2, dim=1).unsqueeze(-1)
    axis = torch.div(aa, angle)
    half_angle = angle * 0.5
    quat = torch.cat([torch.cos(half_angle), torch.sin(half_angle) * axis], dim=1)
    return quat_to_rotmat(quat).reshape(*lead_shape, 3, 3)
635
+
636
+
637
def rotation_matrix_to_euler_angles(rotation_matrix):
    """Convert (..., 3, 3) rotation matrices to xyz Euler angles in degrees.

    Returns a flattened (N, 3) array/tensor; torch inputs come back as torch
    tensors on their original device.
    """
    from scipy.spatial.transform import Rotation

    if isinstance(rotation_matrix, Tensor):
        device = rotation_matrix.device
        mats = rotation_matrix.cpu().numpy().reshape(-1, 3, 3)
        angles = Rotation.from_matrix(mats).as_euler("xyz", degrees=True)
        return torch.from_numpy(angles).to(device)
    mats = rotation_matrix.reshape(-1, 3, 3)
    return Rotation.from_matrix(mats).as_euler("xyz", degrees=True)
651
+
652
+
653
def euler_angles_to_rotation_matrix(euler_angles, degrees=True):
    """Convert xyz Euler angles to 3x3 rotation matrices.

    Args:
        euler_angles: angles in xyz order; any shape with last dimension 3.
        degrees: interpret angles as degrees when True, radians otherwise.

    Returns:
        np.ndarray of rotation matrices, shape (..., 3, 3).
    """
    from scipy.spatial.transform import Rotation

    lead_shape = euler_angles.shape[:-1]
    mats = Rotation.from_euler("xyz", euler_angles.reshape(-1, 3), degrees=degrees).as_matrix()
    return mats.reshape(*lead_shape, 3, 3)
669
+
670
+
671
def get_local_transl_vel(transl, global_orient_R, fps=30):
    """Compute per-frame translation velocity in the local (SMPL) frame.

    Args:
        transl: (*, L, 3) global translations.
        global_orient_R: (*, L, 3, 3) global-orientation rotation matrices.
        fps: frame rate used to scale finite differences into velocity.

    Returns:
        (*, L, 3) local-frame translation velocities; frame 0 is zero-padded.
    """
    transl_vel = transl[..., 1:, :] - transl[..., :-1, :]  # (*, L-1, 3)
    # Pad a zero velocity at the first frame. Index the *time* axis with an
    # ellipsis: the previous `transl_vel[:1]` sliced the leading (batch) axis,
    # which produced a shape mismatch (cat error) for batched (B, L, 3) input.
    pad = torch.zeros_like(transl_vel[..., :1, :])
    transl_vel = torch.cat([pad, transl_vel], dim=-2) * fps  # (*, L, 3)

    # v_local = R^T @ v_global (einsum sums over the row index i, i.e. R^T v).
    local_transl_vel = torch.einsum("...lij,...li->...lj", global_orient_R, transl_vel)
    return local_transl_vel
687
+
688
+
689
def compute_transl_full_cam(pred_cam, bbx_xys, K_fullimg):
    """Lift a weak-perspective camera predicted in a bbox crop to a full-image
    translation using the camera intrinsics.

    Args:
        pred_cam: (..., 3) weak-perspective parameters (s, tx, ty).
        bbx_xys: (..., 3) bounding boxes as (center_x, center_y, size).
        K_fullimg: (..., 3, 3) full-image intrinsics.

    Returns:
        (..., 3) camera-frame translation.
    """
    scale = pred_cam[..., 0]
    tx = pred_cam[..., 1]
    ty = pred_cam[..., 2]
    focal = K_fullimg[..., 0, 0]
    cx0 = K_fullimg[..., 0, 2]
    cy0 = K_fullimg[..., 1, 2]

    sb = scale * bbx_xys[..., 2]
    # Offsets induced by the bbox center relative to the principal point.
    dx = 2 * (bbx_xys[..., 0] - cx0) / (sb + 1e-9)
    dy = 2 * (bbx_xys[..., 1] - cy0) / (sb + 1e-9)
    tz = 2 * focal / (sb + 1e-9)

    return torch.stack([tx + dx, ty + dy, tz], dim=-1)
702
+
703
+
704
def quaternion_fix_continuity(q: Tensor) -> Tensor:
    """Enforce temporal continuity of a quaternion sequence.

    q and -q encode the same rotation; frame by frame this picks the sign that
    maximizes the dot product with the previous frame.

    Args:
        q: (L, 4) or (L, J, 4) quaternion sequence.

    Returns:
        A sign-fixed copy of q.
    """
    assert q.ndim in (
        2,
        3,
    ), f"Expected 3D tensor (L, J, 4), or 2D tensor (L, 4), but got shape {q.shape}"
    assert q.shape[-1] == 4, f"Last dimension should be 4 for quaternions, got {q.shape[-1]}"
    if q.shape[0] <= 1:
        # A single frame (or empty sequence) has nothing to align with.
        return q.clone()

    fixed = q.clone()
    # A negative dot product between consecutive frames marks a sign flip.
    needs_flip = torch.sum(q[1:] * q[:-1], dim=-1) < 0
    # Flips propagate: frame t must be negated iff an odd number of flips
    # occurred up to (and including) t.
    odd_flips = (torch.cumsum(needs_flip.int(), dim=0) % 2).bool()
    fixed[1:][odd_flips] *= -1
    return fixed
726
+
727
+
728
def rot_mat2trans_mat(rot_mat: np.ndarray) -> np.ndarray:
    """Embed rotation matrices (..., 3, 3) into homogeneous 4x4 transforms
    with zero translation. Supports unbatched, 1- and 2-level batched input."""
    eye4 = np.identity(4)
    if rot_mat.ndim == 2:
        trans_mat = eye4
    elif rot_mat.ndim == 3:
        trans_mat = np.tile(eye4, [rot_mat.shape[0], 1, 1])
    elif rot_mat.ndim == 4:
        trans_mat = np.tile(eye4, [rot_mat.shape[0], rot_mat.shape[1], 1, 1])
    else:
        raise NotImplementedError
    trans_mat[..., :3, :3] = rot_mat
    return trans_mat
741
+
742
+
743
def trans2trans_mat(trans: np.ndarray) -> np.ndarray:
    """Embed translation vectors (..., 3) into homogeneous 4x4 transforms
    with identity rotation. Supports unbatched, 1- and 2-level batched input."""
    assert trans.shape[-1] == 3
    assert (len(trans.shape) == 1) or (len(trans.shape) == 2) or (len(trans.shape) == 3), trans.shape
    if trans.ndim == 1:
        trans_mat = np.identity(4)
        trans_mat[:3, 3] = trans
    elif trans.ndim == 2:
        trans_mat = np.tile(np.identity(4), [trans.shape[0], 1, 1])
        trans_mat[:, :3, 3] = trans
    else:  # ndim == 3 (guaranteed by the assert above)
        trans_mat = np.tile(np.identity(4), [trans.shape[0], trans.shape[1], 1, 1])
        trans_mat[:, :, :3, 3] = trans
    return trans_mat
758
+
759
+
760
def gaussian_kernel1d(sigma: float, order: int, radius: int) -> np.ndarray:
    """Computes a 1D Gaussian convolution kernel (or its `order`-th derivative).

    (from scipy)

    Args:
        sigma: standard deviation of the Gaussian.
        order: derivative order; 0 returns the plain (normalized) Gaussian.
        radius: half-width of the kernel; output has 2 * radius + 1 taps.

    Returns:
        np.ndarray: the kernel values on the grid [-radius, radius].
    """
    if order < 0:
        raise ValueError("order must be non-negative")
    exponent_range = np.arange(order + 1)
    sigma2 = sigma * sigma
    x = np.arange(-radius, radius + 1)
    # Normalized Gaussian phi(x).
    phi_x = np.exp(-0.5 / sigma2 * x**2)
    phi_x = phi_x / phi_x.sum()

    if order == 0:
        return phi_x
    else:
        # f(x) = q(x) * phi(x) = q(x) * exp(p(x))
        # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
        # p'(x) = -1 / sigma ** 2
        # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
        # coefficients of q(x)
        q = np.zeros(order + 1)
        q[0] = 1
        D = np.diag(exponent_range[1:], 1)  # D @ q(x) = q'(x)
        P = np.diag(np.ones(order) / -sigma2, -1)  # P @ q(x) = q(x) * p'(x)
        Q_deriv = D + P
        # Apply the derivative operator `order` times to the polynomial coeffs.
        for _ in range(order):
            q = Q_deriv.dot(q)
        # Evaluate the polynomial q at each x (Vandermonde-style).
        q = (x[:, None] ** exponent_range).dot(q)
        return q * phi_x
790
+
791
+
792
def slice_seq_with_padding(whole_seq: np.ndarray, middle_idx: int, length: int) -> np.ndarray:
    """Extract a window of `length` frames centered at `middle_idx`,
    edge-padding with copies of the first/last frame where the window would
    run past the sequence bounds."""
    half = length // 2
    padded = whole_seq.copy()

    left_pad = max(0, half - middle_idx)
    if left_pad:
        padded = np.concatenate([np.stack([padded[0]] * left_pad), padded], axis=0)
    right_pad = max(0, middle_idx + length - half - len(whole_seq))
    if right_pad:
        padded = np.concatenate([padded, np.stack([padded[-1]] * right_pad)], axis=0)

    assert len(padded) == len(whole_seq) + left_pad + right_pad
    start = middle_idx + left_pad - half
    assert start >= 0
    assert start + length <= len(padded)
    return padded[start : start + length]
810
+
811
+
812
def wavg_quaternion_markley(Q: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Weighted average of quaternions (Markley et al., 2007).

    Python port of Tolga Birdal's algorithm via https://stackoverflow.com/a/49690919
    (matlab original: mathworks fileexchange 40098, averaging_quaternions).

    Useful e.g. when clustering poses after matching, where a per-rotation
    weight (such as a vote count) is available.

    Args:
        Q (ndarray): (M, 4) quaternions, scalar-first.
        weights: M weights, one per quaternion.

    Returns:
        (4,) averaged quaternion: the eigenvector of the weighted
        outer-product accumulator with the largest eigenvalue.

    Reference:
        Markley, F. Landis, Yang Cheng, John Lucas Crassidis, and Yaakov
        Oshman. "Averaging quaternions." Journal of Guidance, Control, and
        Dynamics 30, no. 4 (2007): 1193-1197.
    """
    accum = np.zeros((4, 4))
    total_weight = 0
    for q, w in zip(Q, weights):
        # Resolve the antipodal ambiguity: q and -q are the same rotation.
        if q[0] < 0:
            q = -q
        accum += w * (np.outer(q, q))  # rank-1 update
        total_weight += w

    accum /= total_weight

    # eigh sorts eigenvalues ascending; take the last eigenvector.
    return np.linalg.eigh(accum)[1][:, -1]
hymotion/utils/loaders.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import json
3
+ import os
4
+
5
+
6
def load_object(module_name, module_args, **extra_args):
    """Instantiate ``pkg.mod.Name`` with keyword arguments.

    Fix: the original called ``module_args.copy()`` *before* the
    ``module_args is None`` check, so passing None always raised
    AttributeError and the None-handling branch was dead code.

    Args:
        module_name: dotted (or slash-separated) path ending in the object
            name, e.g. "torch.nn.MSELoss" or "pkg/mod.Name".
        module_args: dict of constructor kwargs, or None.
        **extra_args: extra kwargs merged on top of module_args.

    Returns:
        The constructed object.
    """
    kwargs = {} if module_args is None else dict(module_args)
    kwargs.update(extra_args)
    module_path = ".".join(module_name.split(".")[:-1]).replace("/", ".")
    module = importlib.import_module(module_path)
    name = module_name.split(".")[-1]
    return getattr(module, name)(**kwargs)
16
+
17
+
18
def load_module(module_name):
    """Resolve a dotted (or slash-separated) path to the attribute it names.

    Fix: import the *full* parent module path. The previous code imported only
    the first dotted component (``module_name.split(".")[0]``), so a deep path
    like "os.path.join" failed with AttributeError; slash paths and two-part
    dotted paths behave exactly as before.

    Args:
        module_name: e.g. "os.path.join", "json.dumps", or "pkg/mod.Name".

    Returns:
        The resolved attribute (class, function, ...).
    """
    parts = module_name.replace("/", ".").split(".")
    module = importlib.import_module(".".join(parts[:-1]))
    return getattr(module, parts[-1])
24
+
25
+
26
def check_cfg(cfg, global_dict, verbose=True):
    """Recursively substitute "$name" string values in `cfg` (in place) with
    the corresponding entries of `global_dict`."""
    for key, val in cfg.items():
        if isinstance(val, dict):
            check_cfg(val, global_dict, verbose)
        elif isinstance(val, str) and val.startswith("$"):
            if verbose:
                print(f" - Update {key} with {val} = {global_dict[val[1:]]}")
            cfg[key] = global_dict[val[1:]]
35
+
36
+
37
def read_yaml(yamlname):
    """Load a YAML file.

    Falls back to FullLoader when safe_load rejects python-object tags, and
    unwraps mmcv-style Config objects into plain dicts.
    """
    import yaml

    with open(yamlname, "r", encoding="utf-8") as file:
        try:
            data = yaml.safe_load(file)
        except yaml.constructor.ConstructorError:
            # Rewind and retry with the permissive loader.
            file.seek(0)
            data = yaml.load(file, Loader=yaml.FullLoader)

    if hasattr(data, "to_dict"):
        data = data.to_dict()
    elif hasattr(data, "_cfg_dict"):
        data = dict(data._cfg_dict)
    return data
52
+
53
+
54
def write_yaml(data, yamlname):
    """Serialize `data` to `yamlname` as YAML (UTF-8)."""
    import yaml

    with open(yamlname, "w", encoding="utf-8") as fout:
        yaml.dump(data, fout)
59
+
60
+
61
def check_input(data, verbose=True):
    """Pop the "input" key (a list of yaml paths) from `data` and return the
    merged contents of those files; later files override earlier ones."""
    merged = {}
    if "input" in data:
        if verbose:
            print(" - Check input file list")
        for filename in data.pop("input"):
            merged.update(read_yaml(filename))
    return merged
70
+
71
+
72
def merge_dict(dict_A, dict_B, key, verbose=True):
    """Recursively merge dict_B[key] into dict_A[key], in place.

    dict_A[key] wins on structure: keys present in both are merged
    recursively; keys only in dict_B[key] are then copied over. When
    dict_A[key] is not a dict, dict_B[key] simply replaces it.

    NOTE(review): ``dict_B.copy()`` is a *shallow* copy, so the
    ``dict_B[key].pop(key2)`` below still mutates the caller's nested dict.
    The only caller visible here (read_config) pops the whole key afterward,
    so this appears harmless there — confirm before reusing elsewhere.
    """
    if isinstance(dict_A[key], dict):
        dict_B = dict_B.copy()
        # Merge keys shared by both sides, removing them from dict_B[key]
        # so only the genuinely new keys remain afterwards.
        for key2, val2 in dict_A[key].items():
            if key2 in dict_B[key]:
                merge_dict(dict_A[key], dict_B[key], key2, verbose)
                dict_B[key].pop(key2)
        if len(dict_B[key]) > 0:
            if verbose:
                print(f" - Create {key} with {dict_B[key]}")
            # Copy over keys that exist only in dict_B[key].
            for key2, val2 in dict_B[key].items():
                dict_A[key][key2] = val2
    else:
        if verbose:
            print(f" - Update {key} with {dict_B[key]}")
        dict_A[key] = dict_B[key]
88
+
89
+
90
def read_config(cfgname, verbose=True):
    """Read a yaml config, resolving "input" parent files and "$var" refs.

    The child config (cfgname) overrides its parents: shared keys are merged
    via merge_dict, remaining child keys are copied over, then check_cfg
    substitutes "$name" placeholders using the merged dict itself as the
    variable table.
    """
    data_base = read_yaml(cfgname)
    # check_input pops "input" from data_base and loads the parent configs.
    data_parent = check_input(data_base, verbose)
    # merge the data_base to data_parent
    for key, val in data_parent.items():
        if key in data_base:
            merge_dict(data_parent, data_base, key, verbose)
            if verbose:
                print(data_parent[key])
            # Drop merged keys so the final update() only adds child-only keys.
            data_base.pop(key)
    data_parent.update(data_base)
    data = data_parent
    # Resolve "$name" string values against the merged config itself.
    check_cfg(data, data, verbose)
    return data
104
+
105
+
106
def update_config(config, args):
    """Overwrite entries of `config` (in place) with same-named, non-None
    attributes of the argparse namespace `args`."""
    for name, value in vars(args).items():
        if value is not None and name in config:
            config[name] = value
110
+
111
+
112
def read_yaml_full(path):
    """Load a YAML file with the (permissive) FullLoader, no fallback logic."""
    import yaml

    with open(path, "r") as fin:
        return yaml.load(fin, Loader=yaml.FullLoader)
117
+
118
+
119
def check_ceph_path(path):
    """Return `path` unchanged if it exists on disk, else raise ValueError."""
    import os

    if not os.path.exists(path):
        raise ValueError(f"{path} not found")
    return path
126
+
127
+
128
def read_json(filename):
    """Load and return the contents of a UTF-8 JSON file."""
    with open(filename, "r", encoding="utf-8") as fin:
        return json.load(fin)
131
+
132
+
133
def write_json(data, filename):
    """Write `data` as pretty-printed UTF-8 JSON (non-ASCII kept literal)."""
    with open(filename, "w", encoding="utf-8") as fout:
        json.dump(data, fout, ensure_ascii=False, indent=4)
136
+
137
+
138
def load_h5_dataset(filename, ds_name_list=None, parser=None):
    """Load an HDF5 file into a nested dict of arrays and attributes.

    Args:
        filename: path to the .h5 file; an optional "@start:end" suffix slices
            the "LclRotation"/"LclTranslation" datasets along frames.
        ds_name_list: optional whitelist of dataset/group names to load.
        parser: optional {name: callable} applied to matching datasets.

    Returns:
        dict mapping names to numpy values (datasets), nested dicts (groups),
        and attribute values; file-level attrs are merged at the top level.
    """
    import h5py

    # ds for dataset
    if "@" in filename:
        filename, start_end = filename.split("@")
        start = int(start_end.split(":")[0])
        end = int(start_end.split(":")[1])
    else:
        start = None
        end = None
    assert os.path.isfile(filename), "cannot find: {}".format(filename)

    def load_dict(d):
        # Recursively materialize a group: datasets -> arrays, groups -> dicts.
        ds_dict = {}
        for item in d.keys():
            if ds_name_list is not None and item not in ds_name_list:
                continue
            if isinstance(d[item], h5py._hl.dataset.Dataset):
                # `[()]` reads the full dataset into memory.
                ds_dict[item] = d[item][()]
                if parser is not None and item in parser:
                    ds_dict[item] = parser[item](ds_dict[item])
            elif isinstance(d[item], h5py._hl.group.Group):
                ds_dict[item] = load_dict(d[item])
        # Attributes override same-named datasets at this level.
        for item in d.attrs.keys():
            ds_dict[item] = d.attrs[item]
        return ds_dict

    with h5py.File(filename, "r") as f:
        ds_dict = load_dict(f)
        for item in f.attrs.keys():
            ds_dict[item] = f.attrs[item]
        if start is not None and end is not None:
            # NOTE(review): the "@start:end" slice assumes these two keys
            # exist in the file — KeyError otherwise; confirm with callers.
            for key in ["LclRotation", "LclTranslation"]:
                ds_dict[key] = ds_dict[key][start:end]
    return ds_dict
174
+
175
+
176
if __name__ == "__main__":
    # Smoke tests for the loader helpers. Running them requires the project
    # package plus torch/diffusers installed; they exercise both the dotted
    # and slash-separated module-path spellings.
    # hymotion.utils.loaders
    network = load_object("hymotion.utils.base_example.ToyNetwork", {})
    print(network)
    network = load_object("hymotion/utils/base_example.ToyNetwork", {})
    print(network)
    load_object("diffusers.DDIMScheduler", {})
    module = load_object("torch.nn.MSELoss", {"reduction": "none"})
    print(module)
hymotion/utils/misc.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from collections.abc import Iterable, Sequence
3
+ from importlib import import_module
4
+ from itertools import repeat
5
+ from os import path as osp
6
+ from typing import Any, Callable, Optional, Tuple, Union
7
+
8
+
9
def is_str(x: Any) -> bool:
    """Return True when `x` is a `str` instance.

    Kept for legacy call sites; equivalent to ``isinstance(x, str)``
    (the Python-2 distinction it once covered no longer applies).
    """
    return isinstance(x, str)
15
+
16
+
17
def is_seq_of(seq: Any, expected_type: Any, seq_type: Any = None) -> bool:
    """Check that `seq` is a sequence whose items are all of `expected_type`.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Required container type; defaults to
            ``collections.abc.Sequence``.

    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is None:
        container = Sequence
    else:
        assert isinstance(seq_type, type)
        container = seq_type
    if not isinstance(seq, container):
        return False
    return all(isinstance(item, expected_type) for item in seq)
38
+
39
+
40
def is_list_of(seq: Any, expected_type: Any) -> bool:
    """:func:`is_seq_of` specialized to ``list`` containers."""
    return is_seq_of(seq, expected_type, seq_type=list)
46
+
47
+
48
def is_tuple_of(seq: Any, expected_type: Any) -> bool:
    """:func:`is_seq_of` specialized to ``tuple`` containers."""
    return is_seq_of(seq, expected_type, seq_type=tuple)
54
+
55
+
56
def import_modules_from_strings(
    imports: Union[list[str], str], allow_failed_imports: bool = False
) -> Optional[list[Any]]:
    """Import modules named by `imports`.

    Fix: on failure the original executed a bare ``raise ImportError``, which
    replaced the real exception and discarded its message; we now re-raise
    the original ImportError so callers see the actual cause.

    Args:
        imports: a module name or list of module names; falsy returns None.
        allow_failed_imports: when True, a failed import emits a UserWarning
            and yields None for that entry instead of raising.

    Returns:
        The imported module if `imports` was a string, a list of modules (or
        Nones) otherwise; None for falsy input.

    Raises:
        TypeError: if `imports` is not a (list of) string(s).
        ImportError: if an import fails and failures are not allowed.
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(f"custom_imports must be a list but got type {type(imports)}")
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.")
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if allow_failed_imports:
                warnings.warn(f"{imp} failed to import and is ignored.", UserWarning)
                imported_tmp = None
            else:
                # Preserve the original error message instead of raising a
                # fresh, message-less ImportError.
                raise
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
83
+
84
+
85
+ def _ntuple(n: int) -> Callable:
86
+ def parse(x: Any) -> Tuple:
87
+ if isinstance(x, Iterable) and not isinstance(x, str):
88
+ x = tuple(x)
89
+ if len(x) == 1:
90
+ x = tuple(repeat(x[0], n))
91
+ return x
92
+ return tuple(repeat(x, n))
93
+
94
+ return parse
95
+
96
+
97
# Convenience converters: scalar -> fixed-length tuple; non-string iterables
# pass through (singletons are broadcast). See _ntuple above.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
101
+
102
+
103
def seconds_to_hmsms(seconds: float) -> tuple[int, int, int, int]:
    """Split a duration in seconds into (hours, minutes, seconds, milliseconds)."""
    hours, rem = divmod(seconds, 3600)
    minutes, rem = divmod(rem, 60)
    whole_seconds, frac = divmod(rem, 1)
    return int(hours), int(minutes), int(whole_seconds), int(frac * 1000)
109
+
110
+
111
def frames_to_hmsms(frames: int, frame_rate: int = 30) -> tuple[int, int, int, int]:
    """Convert a frame count at `frame_rate` fps to (hours, minutes, seconds, ms)."""
    return seconds_to_hmsms(frames / frame_rate)
114
+
115
+
116
def make_series(
    data_root: str,
    series_name: str,
    count: int,
    date: str,
    postfix: str = "raw_caption/",
):
    """Build ``{"<series>_packedNN": {"input_text_path": [path]}}`` entries
    for packed subsets 00..count-1, each path rooted at
    data_root/series_name/subset/date/postfix."""
    series = {}
    for idx in range(count):
        subset = f"{series_name}_packed{idx:02d}"
        series[subset] = {
            "input_text_path": [
                osp.join(
                    data_root,
                    series_name,
                    subset,
                    date,
                    f"{postfix}",
                )
            ]
        }
    return series
hymotion/utils/motion_process.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor
6
+
7
+
8
def _hysteresis_and_morph(
    prob: Tensor,
    on_thr: float = 0.7,
    off_thr: float = 0.5,
    morph_min_len: int = 3,
    morph_max_gap: int = 2,
) -> Tensor:
    """Binarize contact probabilities with hysteresis, then morphologically
    clean the result per channel.

    Args:
        prob: (L, K) per-frame, per-joint contact probabilities.
        on_thr: probability above which a channel switches ON.
        off_thr: probability below which a channel switches OFF.
        morph_min_len: ON runs shorter than this are removed.
        morph_max_gap: OFF gaps up to this length between ON runs are filled.

    Returns:
        (L, K) boolean contact mask.
    """
    L, K = prob.shape
    device = prob.device
    contact = torch.zeros_like(prob, dtype=torch.bool)
    prev = torch.zeros((K,), dtype=torch.bool, device=device)
    # Hysteresis: state latches between off_thr and on_thr (prob values in
    # the dead band keep the previous frame's state).
    for t in range(L):
        on = prob[t] > on_thr
        off = prob[t] < off_thr
        prev = torch.where(on, torch.ones_like(prev, dtype=torch.bool), prev)
        prev = torch.where(off, torch.zeros_like(prev, dtype=torch.bool), prev)
        contact[t] = prev

    def morph_clean(x: Tensor, min_len: int = 3, max_gap: int = 2) -> Tensor:
        # Operates on a single (L,) channel; returns a cleaned clone.
        x = x.clone()
        cnt = 0
        # Pass 1: delete ON runs shorter than min_len.
        # NOTE(review): a short run ending exactly at the last frame keeps
        # that final frame (the slice [tt - cnt : tt] excludes tt) — confirm
        # whether that edge behavior is intended.
        for tt in range(L):
            if x[tt]:
                cnt += 1
            if (not x[tt]) or tt == L - 1:
                if 0 < cnt < min_len:
                    x[tt - cnt : tt] = False
                cnt = 0
        # Pass 2: fill OFF gaps of length <= max_gap between two ON runs.
        gap = 0
        last_on = -1
        for tt in range(L):
            if x[tt]:
                if 0 < gap <= max_gap and last_on >= 0:
                    x[last_on + 1 : tt] = True
                last_on = tt
                gap = 0
            else:
                gap += 1
        return x

    # Clean each joint channel independently and restack to (L, K).
    return torch.stack(
        [morph_clean(contact[:, j], morph_min_len, morph_max_gap) for j in range(K)],
        dim=1,
    )
52
+
53
+
54
def correct_translation_with_contact(
    k3d: Tensor,
    transl: Tensor,
    prob: Tensor,
    joint_ids: List[int] = [7, 10, 8, 11],
    on_thr: float = 0.50,
    off_thr: float = 0.30,
    morph_min_len: int = 3,
    morph_max_gap: int = 2,
    eps: float = 1e-8,
) -> Tensor:
    """Remove foot-skate drift from a root translation using contact labels.

    Frames where a contact joint is down should keep that joint world-static;
    any weighted mean displacement of the contacted joints is treated as
    drift and subtracted from the per-frame root displacement (the vertical
    axis, index 1, is left untouched).

    Args:
        k3d: (L, J, 3) or (B, L, J, 3) joint positions.
        transl: (L, 3) or (B, L, 3) root translation.
        prob: (L, K) or (B, L, K) contact probabilities for `joint_ids`.
        joint_ids: indices of the contact joints within J.
            NOTE(review): mutable default list — never mutated here, so it is
            harmless, but a tuple would be safer.
        on_thr / off_thr / morph_min_len / morph_max_gap: hysteresis and
            morphology parameters forwarded to _hysteresis_and_morph.
        eps: lower clamp for the contact-weight normalizer.

    Returns:
        Corrected translation with the same (batched or unbatched) shape as
        `transl`; batch dim is squeezed when B == 1.
    """
    if k3d.dim() == 3:  # (L, J, 3) -> (1, L, J, 3)
        k3d = k3d.unsqueeze(0)
    if transl.dim() == 2:  # (L, 3) -> (1, L, 3)
        transl = transl.unsqueeze(0)
    B, L, J, _ = k3d.shape
    K = len(joint_ids)
    if prob.dim() == 2:  # (L, K)
        contact = _hysteresis_and_morph(prob, on_thr, off_thr, morph_min_len, morph_max_gap)  # (L, K)
        contact = contact.unsqueeze(0).expand(B, -1, -1)  # (B, L, K)
        prob_b = prob.unsqueeze(0).expand(B, -1, -1)  # (B, L, K)
    elif prob.dim() == 3:  # (B, L, K)
        contact_list = []
        prob_b = prob
        for b in range(prob.shape[0]):
            contact_list.append(_hysteresis_and_morph(prob[b], on_thr, off_thr, morph_min_len, morph_max_gap))
        contact = torch.stack(contact_list, dim=0)  # (B, L, K)
    else:
        raise ValueError("prob must be (L,K) or (B,L,K)")
    # A frame pair counts only when the joint is in contact on both frames.
    pair_contact = contact[:, 1:] & contact[:, :-1]  # (B, L-1, K)
    pred_j3d_static = k3d[:, :, joint_ids, :]  # (B, L, K, 3)
    pred_j3d_static_disp = pred_j3d_static[:, 1:] - pred_j3d_static[:, :-1]  # (B, L-1, K, 3)
    # Soft weights: mean contact probability of each frame pair, gated by
    # the hard contact mask.
    w = 0.5 * (prob_b[:, 1:] + prob_b[:, :-1])  # (B, L-1, K)
    w = w * pair_contact.float()
    w_sum = w.sum(dim=2, keepdim=True).clamp_min(eps)  # (B, L-1, 1)
    # Weighted mean displacement of contacted joints = estimated drift.
    drift = (pred_j3d_static_disp * w.unsqueeze(-1)).sum(dim=2) / w_sum  # (B, L-1, 3)
    # Do not correct the vertical axis.
    drift[..., 1] = 0.0
    w_disp = transl[:, 1:] - transl[:, :-1]  # (B, L-1, 3)
    w_disp_new = w_disp - drift
    # Re-integrate the corrected displacements from the original first frame.
    transl_fixed = torch.zeros_like(transl)
    transl_fixed[:, 0] = transl[:, 0]
    transl_fixed[:, 1:] = transl_fixed[:, :1] + torch.cumsum(w_disp_new, dim=1)
    return transl_fixed.squeeze(0) if transl_fixed.shape[0] == 1 else transl_fixed
97
+
98
+
99
def smooth_quats(quats: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """Gaussian-smooth a quaternion time series via weighted Markley averaging.

    Args:
        quats: (L, 4) quaternion sequence (scalar-first, per the helpers used).
        sigma: Gaussian std-dev in frames; <= 0 (or empty input) returns a copy.

    Returns:
        (L, 4) smoothed copy of the input.
    """
    from .geometry import gaussian_kernel1d, quaternion_fix_continuity, slice_seq_with_padding, wavg_quaternion_markley

    if len(quats) == 0 or sigma <= 0:
        return quats.copy()

    # Fix q/-q sign flips first so neighboring frames average coherently.
    q_all = quaternion_fix_continuity(torch.from_numpy(quats)).numpy()

    results = q_all.copy()
    # Kernel radius: same truncation rule scipy uses (4 sigma).
    truncate = 4.0
    order = 0
    lw = int(truncate * float(sigma) + 0.5)
    weights = gaussian_kernel1d(sigma=sigma, order=order, radius=lw)[::-1]
    kernel_len = len(weights)

    for fr in range(len(q_all)):
        # Edge-padded window of kernel_len frames centered at fr.
        cur_quats = slice_seq_with_padding(q_all, fr, kernel_len)  # (K,4)
        # Align every window frame with the center frame's hemisphere.
        ref = cur_quats[kernel_len // 2 : kernel_len // 2 + 1]  # (1,4)
        dots = (cur_quats * ref).sum(axis=-1, keepdims=True)  # (K,1)
        cur_quats = np.where(dots < 0.0, -cur_quats, cur_quats)

        results[fr, :] = wavg_quaternion_markley(cur_quats, weights)

    return results.copy()
123
+
124
+
125
def smooth_rotation(
    quats: np.ndarray,
    sigma: float = 1.0,
) -> np.ndarray:
    """Smooth per-joint quaternion tracks over time.

    Args:
        quats: (L, J, 4) or batched (B, L, J, 4) quaternions.
        sigma: Gaussian std-dev in frames, forwarded to smooth_quats.

    Returns:
        Smoothed array with the same shape as the input.

    NOTE(review): this writes results back into `quats` in place — for an
    unbatched input `quats[None, ...]` is a *view*, so the caller's array is
    mutated as a side effect of calling this; confirm callers expect that.
    NOTE(review): quaternion_fix_continuity is applied here and again inside
    smooth_quats — the second pass is redundant but harmless.
    """
    from .geometry import quaternion_fix_continuity

    if quats.ndim == 4:
        is_batch = True
    else:
        is_batch = False
        # View (not a copy): adds a leading batch axis of size 1.
        quats = quats[None, ...]
    for b in range(quats.shape[0]):
        for j_idx in range(quats.shape[2]):
            cur_quats = quats[b, :, j_idx].copy()
            cur_quats_t = quaternion_fix_continuity(torch.from_numpy(cur_quats)).numpy()
            quats[b, :, j_idx] = smooth_quats(cur_quats_t, sigma=sigma)
    if not is_batch:
        quats = quats.squeeze(0)
    return quats
146
+
147
+
148
def unwrap_euler_over_time(xyz: torch.Tensor) -> torch.Tensor:
    """Unwrap Euler-angle sequences along the time axis so trajectories are
    continuous (no +/- 2*pi jumps between consecutive frames).

    Args:
        xyz: (B, L, J, 3) angle sequences in radians.

    Returns:
        Unwrapped copy: y[t] = y[0] + cumsum of wrapped per-frame deltas.
    """
    out = xyz.clone()
    delta = out[:, 1:] - out[:, :-1]
    # atan2(sin d, cos d) maps each delta into the principal range (-pi, pi].
    wrapped = torch.atan2(torch.sin(delta), torch.cos(delta))
    out[:, 1:] = out[:, :1] + torch.cumsum(wrapped, dim=1)
    return out
hymotion/utils/path.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import platform
4
+ from pathlib import Path
5
+ from typing import Any, Generator, List, Optional, Union
6
+
7
+ from .misc import is_str
8
+
9
+ if platform.system() == "Windows":
10
+ import regex as re
11
+ else:
12
+ import re
13
+
14
+
15
def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None:
    """Raise FileNotFoundError (message from `msg_tmpl`) unless `filename`
    is an existing regular file."""
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
18
+
19
+
20
def mkdir_or_exist(dir_name: str, mode: int = 0o777) -> None:
    """Create `dir_name` (including parents) if needed; "" is a no-op and an
    already-existing directory is not an error."""
    if dir_name == "":
        return
    expanded = osp.expanduser(dir_name)
    os.makedirs(expanded, mode=mode, exist_ok=True)
25
+
26
+
27
def symlink(src: str, dst: str, overwrite: bool = True, **kwargs) -> None:
    """Create a symlink dst -> src, removing an existing dst first when
    `overwrite` is True (extra kwargs are passed to os.symlink)."""
    if overwrite and os.path.lexists(dst):
        os.remove(dst)
    os.symlink(src, dst, **kwargs)
31
+
32
+
33
def is_filepath(x: Any) -> bool:
    """True when `x` is a path-like value: a str or a pathlib.Path."""
    return isinstance(x, Path) or is_str(x)
35
+
36
+
37
def scandir(
    dir_path: Union[str, Path],
    suffix: Optional[str] = None,
    recursive: bool = False,
    case_sensitive: bool = True,
) -> Generator[str, None, None]:
    """Scan a directory to find the interested files.

    Args:
        dir_path (str | :obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        case_sensitive (bool, optional) : If set to False, ignore the case of
            suffix. Default: True.
    Returns:
        A generator for all the interested files with relative paths
        (relative to `dir_path`); hidden files (dot-prefixed) are skipped.
    """
    # Argument validation happens eagerly (before the generator is created),
    # so bad inputs raise at call time rather than on first iteration.
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        # Lowercase the suffix once; paths are lowercased per entry below.
        suffix = suffix.lower() if isinstance(suffix, str) else tuple(item.lower() for item in suffix)

    # Paths are yielded relative to the original root, even when recursing.
    root = dir_path

    def _scandir(
        dir_path: Union[str, Path],
        suffix: Optional[str],
        recursive: bool,
        case_sensitive: bool,
    ) -> Generator[str, None, None]:
        for entry in os.scandir(dir_path):
            if not entry.name.startswith(".") and entry.is_file():
                rel_path = osp.relpath(entry.path, root)
                _rel_path = rel_path if case_sensitive else rel_path.lower()
                # str.endswith accepts a tuple, covering the multi-suffix case.
                if suffix is None or _rel_path.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                # scan recursively if entry.path is a directory
                yield from _scandir(entry.path, suffix, recursive, case_sensitive)

    return _scandir(dir_path, suffix, recursive, case_sensitive)
86
+
87
+
88
def find_files(directory, pattern, recursive=True, abspath=False) -> List[str]:
    """Collect files under ``directory`` whose basename matches ``pattern``.

    Args:
        directory: root to walk.
        pattern: regex matched from the start of each basename.
        recursive: when False only the top-level directory is scanned.
        abspath: return absolute paths instead of cwd-relative ones.

    Returns:
        Sorted list of matching paths.
    """
    matcher = re.compile(pattern)
    matched = []
    for root, _, filenames in os.walk(directory):
        matched.extend(os.path.join(root, name) for name in filenames if matcher.match(name))
        if not recursive:
            break
    normalize = os.path.abspath if abspath else os.path.relpath
    return [normalize(p) for p in sorted(matched)]
99
+
100
+
101
def natural_keys(text: str, retoken: str = r"[a-zA-Z]*(\d+)[a-zA-Z_]*[\.].*", n: int = 1) -> Union[int, str]:
    """Extract the ``n``-th split token of ``text`` for natural (numeric-aware) sorting.

    Digit tokens become ints so "a2" sorts before "a10"; anything else is lower-cased.
    """
    token = re.split(retoken, text)[n]
    return int(token) if token.isdigit() else token.lower()
106
+
107
+
108
def listdirs(root):
    """Return every sub-directory path (joined with its parent) under ``root``.

    PEP 8 (E731): named functions instead of lambdas assigned to names;
    behavior is identical to the original lambdas.
    """
    return [osp.join(base, d) for base, dirs, _ in os.walk(root) if dirs for d in dirs]


def listfiles(root):
    """Return the basenames of every file found anywhere under ``root``."""
    return [f for base, _, files in os.walk(root) if files for f in files]
111
+
112
+
113
def parse_dirs_and_sort(
    input_dirs: Union[list, str],
    suffix: str,
    is_sort: bool = False,
    with_prefix: bool = True,
) -> List[str]:
    """Expand directories and/or files into a flat list of files with ``suffix``.

    Args:
        input_dirs: a directory path, a single file path, or a list of either.
        suffix: required file suffix (matched case-insensitively when scanning dirs).
        is_sort: natural-sort the result, with progressively simpler fallbacks.
        with_prefix: prepend the scanned directory to each relative path.

    Returns:
        List of file paths.

    Raises:
        ValueError: if a path does not exist or the input type is unsupported.
    """
    if isinstance(input_dirs, list):
        input_dirs_list = []
        for iter_input_dir in input_dirs:
            if osp.isdir(iter_input_dir):
                input_dirs_list += [
                    osp.join(iter_input_dir, x) if with_prefix else x
                    for x in scandir(
                        iter_input_dir,
                        suffix=suffix,
                        recursive=True,
                        case_sensitive=False,
                    )
                ]
            elif osp.isfile(iter_input_dir):
                if iter_input_dir.endswith(suffix):
                    input_dirs_list += [iter_input_dir]
            else:
                raise ValueError(f"Input path {iter_input_dir} is not exist.")
    elif isinstance(input_dirs, str):
        if osp.isdir(input_dirs):
            input_dirs_list = [
                osp.join(input_dirs, x) if with_prefix else x
                for x in scandir(input_dirs, suffix=suffix, recursive=True, case_sensitive=False)
            ]
        elif osp.isfile(input_dirs):
            if input_dirs.endswith(suffix):
                input_dirs_list = [input_dirs]
            else:
                input_dirs_list = []
        else:
            raise ValueError(f"Input path {input_dirs} is not exist.")
    else:
        raise ValueError("Only support list or str input.")

    if is_sort:
        # Bug fix: the original used bare ``except:`` which also swallows
        # SystemExit/KeyboardInterrupt; narrowed to Exception.
        try:
            try:
                # Two-key natural sort: first a "<num>_..." token, then "_<num>...".
                input_dirs_list = sorted(
                    input_dirs_list,
                    key=lambda text: (
                        natural_keys(text, retoken=r"[a-zA-Z]*(\d+)_[0-9a-zA-Z_]*[\.].*", n=1),
                        natural_keys(text, retoken=r"[0-9a-zA-Z]*_(\d+)[a-zA-Z_]*[\.].*", n=1),
                    ),
                )
            except Exception:
                # Fallback: single-token natural sort with the default pattern.
                input_dirs_list = sorted(input_dirs_list, key=lambda text: (natural_keys(text)))
        except Exception:
            # Final fallback: plain lexicographic sort.
            input_dirs_list = sorted(input_dirs_list, key=lambda text: text)

    return input_dirs_list
hymotion/utils/smplh2fbx.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import sys
5
+ import tempfile
6
+
7
+ import fbx
8
+ import numpy as np
9
+ import torch
10
+ from transforms3d.euler import mat2euler
11
+
12
+ from .geometry import angle_axis_to_rotation_matrix, rot_mat2trans_mat, trans2trans_mat
13
+
14
+ # yapf: disable
15
+ SMPLH_JOINT2NUM = {
16
+ "Pelvis": 0, "L_Hip": 1, "R_Hip": 2, "Spine1": 3,
17
+ "L_Knee": 4, "R_Knee": 5, "Spine2": 6,
18
+ "L_Ankle": 7, "R_Ankle": 8,
19
+ "Spine3": 9,
20
+ "L_Foot": 10, "R_Foot": 11,
21
+ "Neck": 12, "L_Collar": 13, "R_Collar": 14, "Head": 15,
22
+ "L_Shoulder": 16, "R_Shoulder": 17,
23
+ "L_Elbow": 18, "R_Elbow": 19,
24
+ "L_Wrist": 20, "R_Wrist": 21,
25
+ # "Jaw": 22, "L_Eye": 23, "R_Eye": 24,
26
+ "L_Index1": 22, "L_Index2": 23, "L_Index3": 24,
27
+ "L_Middle1": 25, "L_Middle2": 26, "L_Middle3": 27,
28
+ "L_Pinky1": 28, "L_Pinky2": 29, "L_Pinky3": 30,
29
+ "L_Ring1": 31, "L_Ring2": 32, "L_Ring3": 33,
30
+ "L_Thumb1": 34, "L_Thumb2": 35, "L_Thumb3": 36,
31
+ "R_Index1": 37, "R_Index2": 38, "R_Index3": 39,
32
+ "R_Middle1": 40, "R_Middle2": 41, "R_Middle3": 42,
33
+ "R_Pinky1": 43, "R_Pinky2": 44, "R_Pinky3": 45,
34
+ "R_Ring1": 46, "R_Ring2": 47, "R_Ring3": 48,
35
+ "R_Thumb1": 49, "R_Thumb2": 50, "R_Thumb3": 51,
36
+ }
37
+ # yapf: enable
38
+
39
+
40
+ def _parse_obj_file(obj_path):
41
+ vertices = []
42
+ uv_coords = []
43
+ faces = []
44
+ uv_faces = []
45
+
46
+ with open(obj_path, "r") as f:
47
+ for line in f:
48
+ line = line.strip()
49
+ if line.startswith("v "):
50
+ parts = line.split()
51
+ vertices.append([float(parts[1]), float(parts[2]), float(parts[3])])
52
+ elif line.startswith("vt "):
53
+ parts = line.split()
54
+ uv_coords.append([float(parts[1]), float(parts[2])])
55
+ elif line.startswith("f "):
56
+ parts = line.split()
57
+ face_vertices = []
58
+ face_uvs = []
59
+ for part in parts[1:]:
60
+ indices = part.split("/")
61
+ face_vertices.append(int(indices[0]) - 1)
62
+ if len(indices) > 1 and indices[1]:
63
+ face_uvs.append(int(indices[1]) - 1)
64
+
65
+ if len(face_vertices) == 3:
66
+ faces.append(face_vertices)
67
+ if len(face_uvs) == 3:
68
+ uv_faces.append(face_uvs)
69
+
70
+ return np.array(vertices), np.array(uv_coords), np.array(faces), np.array(uv_faces)
71
+
72
+
73
+ def _blend_shapes(betas: torch.Tensor, shape_disps: torch.Tensor) -> torch.Tensor:
74
+ """Calculates the per vertex displacement due to the blend shapes.
75
+
76
+ Parameters
77
+ ----------
78
+ betas : torch.tensor Bx(num_betas)
79
+ Blend shape coefficients
80
+ shape_disps: torch.tensor Vx3x(num_betas)
81
+ Blend shapes
82
+
83
+ Returns
84
+ -------
85
+ torch.tensor BxVx3
86
+ The per-vertex displacement due to shape deformation
87
+ """
88
+
89
+ # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
90
+ # i.e. Multiply each shape displacement by its corresponding beta and
91
+ # then sum them.
92
+ blend_shape = torch.einsum("bl,mkl->bmk", [betas, shape_disps])
93
+ return blend_shape
94
+
95
+
96
+ def _vertices2joints(J_regressor: torch.Tensor, vertices: torch.Tensor) -> torch.Tensor:
97
+ """Calculates the 3D joint locations from the vertices.
98
+
99
+ Parameters
100
+ ----------
101
+ J_regressor : torch.tensor JxV
102
+ The regressor array that is used to calculate the joints from the
103
+ position of the vertices
104
+ vertices : torch.tensor BxVx3
105
+ The tensor of mesh vertices
106
+
107
+ Returns
108
+ -------
109
+ torch.tensor BxJx3
110
+ The location of the joints
111
+ """
112
+
113
+ return torch.einsum("bik,ji->bjk", [vertices, J_regressor])
114
+
115
+
116
def _addSmplXMesh(fbxScene, v_posed, faces, uv_coords=None, uv_faces=None):
    """Create a "Geometry" node carrying the body mesh (plus optional UV layer).

    Args:
        v_posed: (V, 3) vertex positions of the shaped template mesh.
        faces: (F, 3) triangle vertex indices.
        uv_coords: optional (U, 2) UV coordinates.
        uv_faces: optional per-triangle UV index triples.

    Returns:
        The newly created FbxNode holding the mesh attribute.
    """
    # Obtain a reference to the scene's root node.
    rootNode = fbxScene.GetRootNode()

    # Create a new node in the scene.
    geometryNode = fbx.FbxNode.Create(fbxScene, "Geometry")
    rootNode.AddChild(geometryNode)

    # Create a new mesh node attribute in the scene, and
    # set it as the new node's attribute
    mesh = fbx.FbxMesh.Create(fbxScene, "body")
    geometryNode.SetNodeAttribute(mesh)

    # Define the new mesh's control points.
    v_posed = np.array(v_posed)
    faces = np.array(faces)

    # Computed only as a debugging aid for checking coordinate ranges.
    minValue = np.min(v_posed)
    maxValue = np.max(v_posed)

    mesh.InitControlPoints(v_posed.shape[0])
    for i in range(v_posed.shape[0]):
        v = v_posed[i, :]
        vertex = fbx.FbxVector4(v[0], v[1], v[2])
        mesh.SetControlPointAt(vertex, i)

    # One triangle polygon per face row.
    for i in range(faces.shape[0]):
        mesh.BeginPolygon(i)
        mesh.AddPolygon(faces[i, 0])
        mesh.AddPolygon(faces[i, 1])
        mesh.AddPolygon(faces[i, 2])
        mesh.EndPolygon()

    if uv_coords is not None and uv_faces is not None:
        # Per-polygon-vertex UV mapping, indices pointing into the direct array.
        uv_layer = mesh.CreateElementUV("UVSet")
        uv_layer.SetMappingMode(fbx.FbxLayerElement.EMappingMode.eByPolygonVertex)
        uv_layer.SetReferenceMode(fbx.FbxLayerElement.EReferenceMode.eIndexToDirect)

        uv_array = uv_layer.GetDirectArray()
        for i in range(len(uv_coords)):
            uv_array.Add(fbx.FbxVector2(uv_coords[i][0], uv_coords[i][1]))

        uv_index_array = uv_layer.GetIndexArray()
        for i in range(len(uv_faces)):
            for j in range(3):
                uv_index_array.Add(uv_faces[i][j])
    return geometryNode
169
+
170
+
171
def _addSmplXSkeleton(fbxManager, fbxScene, trans, joint2num, kintree_table):
    """Build the SMPL joint hierarchy as FbxSkeleton nodes under a "Reference" node.

    Args:
        trans: (J, 3) parent-relative joint offsets.
        joint2num: mapping joint name -> joint index.
        kintree_table: per-joint parent index (-1 for the root).

    Returns:
        (referenceNode, skeletonNodes) with skeletonNodes indexed by joint index.
    """
    # Invert joint2num into an index -> name lookup list.
    num2joint = ["" for key in joint2num]
    for key, value in joint2num.items():
        num2joint[value] = key

    # Obtain a reference to the scene's root node.
    rootNode = fbxScene.GetRootNode()

    # Create a new node in the scene.
    referenceNode = fbx.FbxNode.Create(fbxScene, "Reference")
    rootNode.AddChild(referenceNode)

    # Create skeletons
    skeletonNodes = []
    for nth in range(len(kintree_table)):
        skeleton = fbx.FbxSkeleton.Create(fbxManager, "")
        # NOTE(review): ``nth == -1`` can never be true, so every node —
        # including the root joint — is created as eLimbNode; confirm intended.
        skeleton.SetSkeletonType(fbx.FbxSkeleton.EType.eRoot if nth == -1 else fbx.FbxSkeleton.EType.eLimbNode)

        node = fbx.FbxNode.Create(fbxScene, num2joint[nth])
        node.SetNodeAttribute(skeleton)

        # Rest-pose local offset relative to the parent joint.
        node.LclTranslation.Set(fbx.FbxDouble3(trans[nth, 0], trans[nth, 1], trans[nth, 2]))

        skeletonNodes.append(node)

        # Parent each joint under its kinematic parent (root has parent -1).
        if kintree_table[nth] != -1:
            skeletonNodes[kintree_table[nth]].AddChild(node)

    referenceNode.AddChild(skeletonNodes[0])
    return referenceNode, skeletonNodes
203
+
204
+
205
def _addSkiningWeight(fbxScene, lbs_weights, geometryNode, skeletonNodes):
    """Bind the mesh to the skeleton using per-vertex linear-blend-skinning weights.

    Args:
        lbs_weights: (V, J) skinning weight matrix; one FbxCluster per joint.
    """
    clusters = []
    for i in range(lbs_weights.shape[1]):
        cluster = fbx.FbxCluster.Create(fbxScene, "")
        cluster.SetLink(skeletonNodes[i])
        cluster.SetLinkMode(fbx.FbxCluster.ELinkMode.eTotalOne)

        # Only record strictly-positive weights to keep clusters sparse.
        for j in range(lbs_weights.shape[0]):
            weight = lbs_weights[j, i]
            if weight > 0:
                cluster.AddControlPointIndex(j, weight)

        clusters.append(cluster)

    # Now we have the Geometry and the skeleton correctly positioned,
    # set the transform and TransformLink matrix accordingly.
    matrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(geometryNode)
    for cluster in clusters:
        cluster.SetTransformMatrix(matrix)

    for i in range(len(skeletonNodes)):
        matrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(skeletonNodes[i])
        clusters[i].SetTransformLinkMatrix(matrix)

    # Add the clusters to the patch by creating a skin and adding those clusters to that skin.
    skin = fbx.FbxSkin.Create(fbxScene, "")
    for cluster in clusters:
        skin.AddCluster(cluster)
    geometryNode.GetNodeAttribute().AddDeformer(skin)
234
+
235
+
236
def _storeBindPose(fbxScene, geometryNode):
    """Record a bind pose containing the mesh node and every linked bone chain."""
    # In the bind pose, we must store all the link's global matrix at the
    # time of the bind.
    # Plus, we must store all the parent(s) global matrix of a link, even
    # if they are not themselves deforming any model.

    clusteredNodes = []
    if geometryNode and geometryNode.GetNodeAttribute():
        skinCount = 0
        clusterCount = 0
        attributeType = geometryNode.GetNodeAttribute().GetAttributeType()
        if attributeType in (
            fbx.FbxNodeAttribute.EType.eMesh,
            fbx.FbxNodeAttribute.EType.eNurbs,
            fbx.FbxNodeAttribute.EType.ePatch,
        ):
            # Count clusters across all skins to know whether anything is bound.
            skinCount = geometryNode.GetNodeAttribute().GetDeformerCount(fbx.FbxDeformer.EDeformerType.eSkin)
            for i in range(skinCount):
                skin = geometryNode.GetNodeAttribute().GetDeformer(i, fbx.FbxDeformer.EDeformerType.eSkin)
                clusterCount += skin.GetClusterCount()

        if clusterCount:
            # Collect every cluster link (bone) plus all of its ancestors.
            for i in range(skinCount):
                skin = geometryNode.GetNodeAttribute().GetDeformer(i, fbx.FbxDeformer.EDeformerType.eSkin)
                clusterCount = skin.GetClusterCount()
                for j in range(clusterCount):
                    link = skin.GetCluster(j).GetLink()
                    _addNodeRecursively(clusteredNodes, link)

            # Add the geometry to the pose
            clusteredNodes += [geometryNode]

    # Now create a bind pose with the link list
    if len(clusteredNodes):
        # A pose must be named. Arbitrarily use the name of the geometry node.
        pose = fbx.FbxPose.Create(fbxScene, geometryNode.GetName())
        pose.SetIsBindPose(True)

        for node in clusteredNodes:
            bindMatrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(node)
            pose.Add(node, fbx.FbxMatrix(bindMatrix))

        fbxScene.AddPose(pose)
279
+
280
+
281
+ def _addNodeRecursively(nodeArray, node):
282
+ """Add the specified node to the node array.
283
+
284
+ Also, add recursively all the parent node of the specified node to the array.
285
+ """
286
+ if node:
287
+ _addNodeRecursively(nodeArray, node.GetParent())
288
+ found = False
289
+ if node in nodeArray:
290
+ if node.GetName() == node.GetName():
291
+ found = True
292
+ if not found:
293
+ nodeArray += [node]
294
+
295
+
296
def _animateGlobalTransformsFromTransMat(animLayer, referenceNode, global_translation, frameDuration):
    """Key the reference node's local translation on all three axes."""
    for axis in ("X", "Y", "Z"):
        _animateSingleChannel(animLayer, referenceNode.LclTranslation, axis, global_translation, frameDuration)
300
+
301
+
302
def _animateSingleChannel(animLayer, component, name, values, frameDuration):
    """Key one axis curve ("X"/"Y"/"Z") of an FBX property with per-frame values.

    Args:
        component: an FBX property such as ``node.LclTranslation``.
        name: channel name; selects ``values[t][0..2]``.
        values: per-frame 3-vectors; one key per frame is written.
        frameDuration: seconds per frame (1 / fps).
    """
    # Map the channel name to the vector component index (default X).
    ncomp = 0

    if name == "X":
        ncomp = 0
    elif name == "Y":
        ncomp = 1
    elif name == "Z":
        ncomp = 2

    time = fbx.FbxTime()
    curve = component.GetCurve(animLayer, name, True)
    curve.KeyModifyBegin()
    for nth in range(len(values)):
        time.SetSecondDouble(nth * frameDuration)
        keyIndex = curve.KeyAdd(time)[0]
        curve.KeySetValue(keyIndex, values[nth][ncomp])
        curve.KeySetInterpolation(
            keyIndex, fbx.FbxAnimCurveDef.EInterpolationType.eInterpolationConstant
        )  # NOTE: using eInterpolationCubic to do interpolation causes error.
    curve.KeyModifyEnd()
323
+
324
+
325
def _animateRotationKeyFrames(animLayer, node, transforms_mat, frameDuration):
    """Key per-frame Euler rotations (degrees, sxyz) extracted from 4x4 transforms."""
    eulers = []
    for frame_idx in range(len(transforms_mat)):
        eulers.append(np.rad2deg(mat2euler(transforms_mat[frame_idx][0:3, 0:3], axes="sxyz")))

    for axis in ("X", "Y", "Z"):
        _animateSingleChannel(animLayer, node.LclRotation, axis, eulers, frameDuration)
333
+
334
+
335
def _animateTranslationKeyFrames(animLayer, node, transforms_mat, frameDuration):
    """Key per-frame translations taken from the 4x4 transforms' last column."""
    positions = [transforms_mat[frame_idx][0:3, 3] for frame_idx in range(len(transforms_mat))]

    for axis in ("X", "Y", "Z"):
        _animateSingleChannel(animLayer, node.LclTranslation, axis, positions, frameDuration)
343
+
344
+
345
def _animateScalingKeyFrames(animLayer, node, transforms_mat, frameDuration):
    """Key per-frame scaling extracted from the 4x4 transforms' diagonal.

    Bug fix: the original keyed scale values onto ``node.LclTranslation``
    (copy-paste from the translation helper); scaling keys belong on
    ``node.LclScaling``.
    """
    scalings = []
    for nth in range(len(transforms_mat)):
        # Diagonal entries hold the per-axis scale (assumes no shear).
        scalings.append(
            np.array(
                (
                    transforms_mat[nth][0, 0],
                    transforms_mat[nth][1, 1],
                    transforms_mat[nth][2, 2],
                )
            )
        )

    for axis in ("X", "Y", "Z"):
        _animateSingleChannel(animLayer, node.LclScaling, axis, scalings, frameDuration)
361
+
362
+
363
def _animateSkeleton(fbxScene, skeletonNodes, frames, frameRate, name="Take1"):
    """Create one animation stack: root translation keys plus per-joint rotations.

    Args:
        frames: (F, J, 4, 4) per-frame local transform matrices.
        frameRate: playback fps; keys are spaced 1/frameRate seconds apart.
        name: take name, or a path whose basename becomes the take name.
    """
    frameDuration = 1.0 / frameRate

    if name != "Take1":
        # Treat ``name`` as a path: keep the basename and drop the last five
        # characters (assumes a 4-char extension plus the dot — TODO confirm).
        subs = name.split("/")
        name = subs[-1][:-5]

    animStack = fbx.FbxAnimStack.Create(fbxScene, name)
    animLayer = fbx.FbxAnimLayer.Create(fbxScene, "Base Layer")
    animStack.AddMember(animLayer)
    # Global trajectory comes from the root joint's translation column.
    _animateGlobalTransformsFromTransMat(
        animLayer=animLayer,
        referenceNode=skeletonNodes[0],
        global_translation=frames[:, 0, :3, 3],
        frameDuration=frameDuration,
    )

    for nId in range(len(skeletonNodes)):
        _animateRotationKeyFrames(
            animLayer=animLayer,
            node=skeletonNodes[nId],
            transforms_mat=frames[:, nId],
            frameDuration=frameDuration,
        )
387
+
388
+
389
def _saveScene(filename, fbxManager, fbxScene):
    """Export ``fbxScene`` to ``filename`` with a freshly created FbxExporter."""
    exporter = fbx.FbxExporter.Create(fbxManager, "")
    initialized = exporter.Initialize(filename)

    if initialized is False:
        raise Exception(
            "Exporter failed to initialized. Error returned: {}".format(exporter.GetStatus().GetErrorString())
        )

    exporter.Export(fbxScene)
    exporter.Destroy()
400
+
401
+
402
def _get_offsets_from_beta(beta, smplx_params, return_template_mesh=True):
    """Compute parent-relative joint offsets (and shaped mesh) for a shape ``beta``.

    Args:
        beta: (1, num_betas) shape coefficients as a torch tensor.
        smplx_params: loaded body-model dict (v_template, shapedirs,
            J_regressor, kintree_table, ...).
        return_template_mesh: also return the shaped template vertices.

    Returns:
        (J, 3) numpy offsets, optionally plus the (1, V, 3) shaped vertices tensor.
    """
    v_template = torch.FloatTensor(smplx_params["v_template"]).unsqueeze(0)
    shape_dirs = torch.FloatTensor(smplx_params["shapedirs"])
    J_regressor = torch.FloatTensor(smplx_params["J_regressor"])

    # Shape-corrected template, then regress absolute joint positions from it.
    v_shaped = v_template + _blend_shapes(beta, shape_dirs)
    J = _vertices2joints(J_regressor, v_shaped).squeeze(0).numpy()

    parents = smplx_params["kintree_table"][()][0]
    parents[0] = -1
    # ``J[()]`` is a no-op full index on an ndarray; the copy prevents the
    # in-place subtraction below from aliasing the regressed joints.
    Translates = J[()].copy()
    # Convert absolute joint positions into parent-relative offsets.
    Translates[1:] -= J[parents[1:]]
    if not return_template_mesh:
        return Translates
    else:
        return Translates, v_shaped
418
+
419
+
420
def _preprocess_smplx(smplx_params, source_anim_data, scale=1, debug=False):
    """Convert axis-angle animation data into per-frame local rotations/translations.

    Args:
        source_anim_data: dict with "betas", "poses" (F, J, 3) axis-angle and
            "trans" (F, 3) root translations.
        scale: uniform scale applied to translations and vertices
            (e.g. 100 when exporting meters as centimeters).

    Returns:
        dict with parent indices, (F, J, 3, 3) local rotations, scaled local
        translations, rest-pose offsets and the shaped template vertices.
    """
    Translates, v_shaped = _get_offsets_from_beta(
        torch.FloatTensor(source_anim_data["betas"]),
        smplx_params,
        return_template_mesh=True,
    )

    parents = smplx_params["kintree_table"][()][0]
    parents[0] = -1

    poses = torch.FloatTensor(source_anim_data["poses"])
    source_LclRotation = angle_axis_to_rotation_matrix(poses).numpy()
    # Rest-pose offsets repeated per frame; the root joint additionally
    # carries the global trajectory.
    source_LclTranslation = np.tile(Translates, (source_LclRotation.shape[0], 1, 1))
    source_LclTranslation[:, 0] += source_anim_data["trans"]

    source_skeleton = {
        "parent": parents,
        "LclRotation": source_LclRotation,
        "LclTranslation": source_LclTranslation * scale,
        "Translate": Translates * scale,
        "v_shaped": v_shaped.squeeze(0).numpy() * scale,
    }
    return source_skeleton
443
+
444
+
445
def _convert_npz_to_fbx(smplh_params, npz_data, save_fn, fps=30, uv_coords=None, uv_faces=None):
    """Assemble a skinned, animated FBX scene from SMPL-H parameters and save it.

    Args:
        smplh_params: loaded SMPL-H model dict (kintree_table, f, weights, ...).
        npz_data: dict with "betas", "poses" (F, J*3 or F, J, 3) and "trans".
        save_fn: output .fbx path.
        fps: animation frame rate.
        uv_coords / uv_faces: optional UV data forwarded to the mesh builder.
    """
    kintree = smplh_params["kintree_table"][0]
    kintree[0] = -1

    source_anim_data = {
        "betas": npz_data["betas"],
        "poses": npz_data["poses"].reshape(npz_data["poses"].shape[0], -1, 3),
        "trans": npz_data["trans"],
    }
    # scale=100: model units (meters) -> FBX centimeters.
    source_skeleton = _preprocess_smplx(smplh_params, source_anim_data, scale=100)
    rot = rot_mat2trans_mat(source_skeleton["LclRotation"])
    trans = trans2trans_mat(source_skeleton["LclTranslation"])
    # Compose translation * rotation into full per-joint local 4x4 transforms.
    frame_data = np.einsum("Btnk,Btkm ->Btnm", trans, rot)

    fbxManager = fbx.FbxManager.Create()
    fbxScene = fbx.FbxScene.Create(fbxManager, "")
    timeMode = fbx.FbxTime().ConvertFrameRateToTimeMode(fps)
    fbxScene.GetGlobalSettings().SetTimeMode(timeMode)

    geometryNode = _addSmplXMesh(
        fbxScene,
        source_skeleton["v_shaped"],
        smplh_params["f"],
        uv_coords=uv_coords,
        uv_faces=uv_faces,
    )
    referenceNode, skeletonNodes = _addSmplXSkeleton(
        fbxManager,
        fbxScene=fbxScene,
        trans=source_skeleton["Translate"],
        joint2num=SMPLH_JOINT2NUM,
        kintree_table=kintree,
    )

    _addSkiningWeight(fbxScene, smplh_params["weights"], geometryNode, skeletonNodes)
    _storeBindPose(fbxScene, geometryNode)
    _animateSkeleton(
        fbxScene=fbxScene,
        skeletonNodes=skeletonNodes,
        frames=frame_data,
        frameRate=fps,
    )

    # Export via a temp file so a failed export never truncates an existing
    # destination file.
    with tempfile.NamedTemporaryFile(suffix=".fbx", delete=False) as tmp_f:
        temp_file = tmp_f.name

    try:
        # Save to temporary location
        _saveScene(temp_file, fbxManager, fbxScene)
        # If successful, copy to final destination
        shutil.copy2(temp_file, save_fn)
    except Exception as e:
        print(f"Error saving FBX file: {e}")
    finally:
        # Remove temporary file
        if os.path.exists(temp_file):
            os.remove(temp_file)

    # CLEANUP
    fbxManager.Destroy()
    del fbxManager, fbxScene
506
+
507
+
508
def _read_uv(obj_template):
    """Load UV coordinates and UV faces from an OBJ template.

    Returns (None, None) when the template is missing or parsing fails.
    """
    if not (obj_template and os.path.isfile(obj_template)):
        return None, None
    try:
        print("Loading UV coordinates from OBJ template: {}".format(obj_template))
        _verts, uv_coords, _faces, uv_faces = _parse_obj_file(obj_template)
        print("Loaded {} UV coordinates and {} UV faces".format(len(uv_coords), len(uv_faces)))
        return uv_coords, uv_faces
    except Exception as e:
        print("Warning: Failed to load UV coordinates from OBJ file: {}".format(e))
        return None, None
521
+
522
+
523
class SMPLH2FBX:
    """Converts SMPL-H pose/shape sequences (npz file or dict) into skinned FBX files."""

    def __init__(
        self,
        obj_template="./assets/smpl_family_models/smplh/textures/male_smplh.obj",
        smplh_model_path="./assets/body_models/smplh/neutral/model.npz",
    ):
        print(f"[{self.__class__.__name__}] Load obj_template: {obj_template}")
        # uv_coords/uv_faces are (None, None) when the template is unavailable.
        self.uv_coords, self.uv_faces = _read_uv(obj_template)
        print(f"[{self.__class__.__name__}] Load smplh_model_path: {smplh_model_path}")
        self.smplh_params = dict(np.load(smplh_model_path, allow_pickle=True))

    def convert_npz_to_fbx(self, npz_file, outname, fps=30):
        """Convert an npz path (or an already-loaded dict) to an FBX at ``outname``.

        Returns:
            True when the output file exists after conversion.
        """
        os.makedirs(os.path.dirname(outname), exist_ok=True)
        if isinstance(npz_file, str) and os.path.isfile(npz_file):
            npz_data = dict(np.load(npz_file, allow_pickle=True))
        else:
            npz_data = npz_file
        # Bug fix: ``fps`` was accepted but never forwarded, so exports
        # always ran at the default 30 fps regardless of the argument.
        _convert_npz_to_fbx(
            self.smplh_params,
            npz_data,
            outname,
            fps=fps,
            uv_coords=self.uv_coords,
            uv_faces=self.uv_faces,
        )
        return os.path.exists(outname)

    def convert_params_to_fbx(self, params, outname):
        """Convert an in-memory params dict (shape-checked) to an FBX at ``outname``.

        Returns:
            True when the output file exists after conversion.
        """
        fps = params.get("mocap_framerate", 30)
        os.makedirs(os.path.dirname(outname), exist_ok=True)
        assert len(params["poses"].shape) == 3, f"poses shape should be (F, 52, 3), but got {params['poses'].shape}"
        assert len(params["betas"].shape) == 2, f"betas shape should be (1, 16), but got {params['betas'].shape}"
        assert len(params["trans"].shape) == 2, f"trans shape should be (1, 3), but got {params['trans'].shape}"
        _convert_npz_to_fbx(
            self.smplh_params,
            params,
            outname,
            fps=fps,
            uv_coords=self.uv_coords,
            uv_faces=self.uv_faces,
        )
        return os.path.exists(outname)
564
+
565
+
566
if __name__ == "__main__":
    # CLI: convert one .npz file, or every .npz in a directory, to FBX.
    # python hymotion/utils/smplh2fbx.py
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("root", type=str)  # file or directory of .npz motions
    args = parser.parse_args()

    converter = SMPLH2FBX()

    if os.path.isdir(args.root):
        npzfiles = sorted(glob.glob(os.path.join(args.root, "*.npz")))
    else:
        if args.root.endswith(".npz"):
            npzfiles = [args.root]
        else:
            raise ValueError(f"Unknown file type: {args.root}")

    # Mirror each .../motions/X.npz into .../motions_fbx/X.fbx.
    for npzfile in npzfiles:
        converter.convert_npz_to_fbx(npzfile, npzfile.replace(".npz", ".fbx").replace("motions", "motions_fbx"))
hymotion/utils/smplh2woodfbx.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ from typing import Dict, Optional
6
+
7
+ import fbx
8
+ import numpy as np
9
+ import torch
10
+ from transforms3d.euler import mat2euler
11
+
12
+ from .geometry import angle_axis_to_rotation_matrix
13
+
14
+ # yapf: disable
15
+ SMPLH_JOINT2NUM = {
16
+ "Pelvis": 0, "L_Hip": 1, "R_Hip": 2, "Spine1": 3,
17
+ "L_Knee": 4, "R_Knee": 5, "Spine2": 6,
18
+ "L_Ankle": 7, "R_Ankle": 8,
19
+ "Spine3": 9,
20
+ "L_Foot": 10, "R_Foot": 11,
21
+ "Neck": 12, "L_Collar": 13, "R_Collar": 14, "Head": 15,
22
+ "L_Shoulder": 16, "R_Shoulder": 17,
23
+ "L_Elbow": 18, "R_Elbow": 19,
24
+ "L_Wrist": 20, "R_Wrist": 21,
25
+ "L_Index1": 22, "L_Index2": 23, "L_Index3": 24,
26
+ "L_Middle1": 25, "L_Middle2": 26, "L_Middle3": 27,
27
+ "L_Pinky1": 28, "L_Pinky2": 29, "L_Pinky3": 30,
28
+ "L_Ring1": 31, "L_Ring2": 32, "L_Ring3": 33,
29
+ "L_Thumb1": 34, "L_Thumb2": 35, "L_Thumb3": 36,
30
+ "R_Index1": 37, "R_Index2": 38, "R_Index3": 39,
31
+ "R_Middle1": 40, "R_Middle2": 41, "R_Middle3": 42,
32
+ "R_Pinky1": 43, "R_Pinky2": 44, "R_Pinky3": 45,
33
+ "R_Ring1": 46, "R_Ring2": 47, "R_Ring3": 48,
34
+ "R_Thumb1": 49, "R_Thumb2": 50, "R_Thumb3": 51,
35
+ }
36
+
37
+ # Mapping from SMPL-H joint names to lowercase names used in some FBX templates
38
+ SMPLH_TO_LOWERCASE_MAPPING = {
39
+ "Pelvis": "pelvis",
40
+ "L_Hip": "left_hip",
41
+ "R_Hip": "right_hip",
42
+ "Spine1": "spine1",
43
+ "L_Knee": "left_knee",
44
+ "R_Knee": "right_knee",
45
+ "Spine2": "spine2",
46
+ "L_Ankle": "left_ankle",
47
+ "R_Ankle": "right_ankle",
48
+ "Spine3": "spine3",
49
+ "L_Foot": "left_foot",
50
+ "R_Foot": "right_foot",
51
+ "Neck": "neck",
52
+ "L_Collar": "left_collar",
53
+ "R_Collar": "right_collar",
54
+ "Head": "head",
55
+ "L_Shoulder": "left_shoulder",
56
+ "R_Shoulder": "right_shoulder",
57
+ "L_Elbow": "left_elbow",
58
+ "R_Elbow": "right_elbow",
59
+ "L_Wrist": "left_wrist",
60
+ "R_Wrist": "right_wrist",
61
+ "L_Index1": "left_index1",
62
+ "L_Index2": "left_index2",
63
+ "L_Index3": "left_index3",
64
+ "L_Middle1": "left_middle1",
65
+ "L_Middle2": "left_middle2",
66
+ "L_Middle3": "left_middle3",
67
+ "L_Pinky1": "left_pinky1",
68
+ "L_Pinky2": "left_pinky2",
69
+ "L_Pinky3": "left_pinky3",
70
+ "L_Ring1": "left_ring1",
71
+ "L_Ring2": "left_ring2",
72
+ "L_Ring3": "left_ring3",
73
+ "L_Thumb1": "left_thumb1",
74
+ "L_Thumb2": "left_thumb2",
75
+ "L_Thumb3": "left_thumb3",
76
+ "R_Index1": "right_index1",
77
+ "R_Index2": "right_index2",
78
+ "R_Index3": "right_index3",
79
+ "R_Middle1": "right_middle1",
80
+ "R_Middle2": "right_middle2",
81
+ "R_Middle3": "right_middle3",
82
+ "R_Pinky1": "right_pinky1",
83
+ "R_Pinky2": "right_pinky2",
84
+ "R_Pinky3": "right_pinky3",
85
+ "R_Ring1": "right_ring1",
86
+ "R_Ring2": "right_ring2",
87
+ "R_Ring3": "right_ring3",
88
+ "R_Thumb1": "right_thumb1",
89
+ "R_Thumb2": "right_thumb2",
90
+ "R_Thumb3": "right_thumb3",
91
+ }
92
+ # yapf: enable
93
+
94
+
95
+ def _loadFbxScene(fbxManager, filepath):
96
+ """Load an FBX file into a scene"""
97
+ importer = fbx.FbxImporter.Create(fbxManager, "")
98
+
99
+ if not importer.Initialize(filepath, -1, fbxManager.GetIOSettings()):
100
+ raise Exception(
101
+ f"Failed to initialize FBX importer for: {filepath}\nError: {importer.GetStatus().GetErrorString()}"
102
+ )
103
+
104
+ fbxScene = fbx.FbxScene.Create(fbxManager, "")
105
+ importer.Import(fbxScene)
106
+ importer.Destroy()
107
+
108
+ return fbxScene
109
+
110
+
111
+ def _collectAllNodes(node, nodes_dict=None):
112
+ """Recursively collect all nodes in the scene hierarchy"""
113
+ if nodes_dict is None:
114
+ nodes_dict = {}
115
+
116
+ nodes_dict[node.GetName()] = node
117
+
118
+ for i in range(node.GetChildCount()):
119
+ _collectAllNodes(node.GetChild(i), nodes_dict)
120
+
121
+ return nodes_dict
122
+
123
+
124
+ def _collectSkeletonNodes(node, skeleton_nodes=None):
125
+ """Recursively collect skeleton/bone nodes"""
126
+ if skeleton_nodes is None:
127
+ skeleton_nodes = {}
128
+
129
+ # Check if this node has a skeleton attribute
130
+ attr = node.GetNodeAttribute()
131
+ if attr and attr.GetAttributeType() == fbx.FbxNodeAttribute.EType.eSkeleton:
132
+ skeleton_nodes[node.GetName()] = node
133
+
134
+ for i in range(node.GetChildCount()):
135
+ _collectSkeletonNodes(node.GetChild(i), skeleton_nodes)
136
+
137
+ return skeleton_nodes
138
+
139
+
140
def _animateSingleChannel(animLayer, component, name, values, frameDuration):
    """Write one keyframe per entry of *values* onto the "X"/"Y"/"Z" curve of *component*."""
    axis_index = {"X": 0, "Y": 1, "Z": 2}.get(name, 0)

    key_time = fbx.FbxTime()
    curve = component.GetCurve(animLayer, name, True)
    curve.KeyModifyBegin()
    for frame_idx, frame_values in enumerate(values):
        key_time.SetSecondDouble(frame_idx * frameDuration)
        key_index = curve.KeyAdd(key_time)[0]
        curve.KeySetValue(key_index, frame_values[axis_index])
        # Constant (stepped) interpolation between keys.
        curve.KeySetInterpolation(key_index, fbx.FbxAnimCurveDef.EInterpolationType.eInterpolationConstant)
    curve.KeyModifyEnd()
153
+
154
+
155
def _animateRotationKeyFrames(animLayer, node, rot_matrices, frameDuration):
    """Key the local-rotation X/Y/Z curves of *node* from per-frame 3x3 rotation matrices."""
    # Convert each rotation matrix to degrees of Euler rotation (static XYZ order).
    rotations = [np.rad2deg(mat2euler(matrix, axes="sxyz")) for matrix in rot_matrices]

    for axis in ("X", "Y", "Z"):
        _animateSingleChannel(animLayer, node.LclRotation, axis, rotations, frameDuration)
166
+
167
+
168
def _animateTranslationKeyFrames(animLayer, node, translations, frameDuration):
    """Key the local-translation X/Y/Z curves of *node* from per-frame (x, y, z) values."""
    # Normalize the input to a float64 ndarray of shape (num_frames, 3).
    if isinstance(translations, torch.Tensor):
        translations = translations.numpy()
    translations = np.asarray(translations, dtype=np.float64)
    if len(translations.shape) == 1:
        # A single frame was given: promote to (1, 3).
        translations = translations.reshape(1, -1)

    for axis in ("X", "Y", "Z"):
        _animateSingleChannel(animLayer, node.LclTranslation, axis, translations, frameDuration)
182
+
183
+
184
def _clearExistingAnimations(fbxScene):
    """Destroy every FbxAnimStack currently attached to the scene."""
    stack_criteria = fbx.FbxCriteria.ObjectType(fbx.FbxAnimStack.ClassId)
    # Walk backwards so Destroy() does not shift indices we have yet to visit.
    for idx in reversed(range(fbxScene.GetSrcObjectCount(stack_criteria))):
        stack = fbxScene.GetSrcObject(stack_criteria, idx)
        if stack:
            stack.Destroy()
191
+
192
+
193
def _applyAnimationToSkeleton(fbxScene, nodes_map, rot_matrices, translations, fps, smplh_to_fbx_mapping, name="Take1"):
    """
    Apply SMPL-H animation data to skeleton nodes in the FBX scene.

    Args:
        fbxScene: FBX scene object
        nodes_map: Dictionary of node_name -> FbxNode
        rot_matrices: (num_frames, num_joints, 3, 3) rotation matrices
        translations: (num_frames, 3) root translations (relative displacement, not absolute position)
        fps: Frame rate
        smplh_to_fbx_mapping: Mapping from SMPL-H joint names to FBX node names
        name: Animation take name

    Returns:
        The newly created FbxAnimStack containing all keyframes.
    """
    frameDuration = 1.0 / fps
    num_frames = rot_matrices.shape[0]
    num_joints = rot_matrices.shape[1]

    # Create animation stack and layer
    animStack = fbx.FbxAnimStack.Create(fbxScene, name)
    animLayer = fbx.FbxAnimLayer.Create(fbxScene, "Base Layer")
    animStack.AddMember(animLayer)

    # Track if root translation was applied
    root_translation_applied = False
    root_node = None

    # Get root node's initial LclTranslation from template (this is like Translates[0] in smplh2woodfbx.py)
    root_initial_translation = None
    root_fbx_name = smplh_to_fbx_mapping.get("Pelvis")
    if root_fbx_name and root_fbx_name in nodes_map:
        root_node_temp = nodes_map[root_fbx_name]
        initial_trans = root_node_temp.LclTranslation.Get()
        root_initial_translation = np.array([initial_trans[0], initial_trans[1], initial_trans[2]])
        print(f"Root initial LclTranslation from template: {root_initial_translation}")

    # Animate each joint
    for smplh_joint_name, smplh_joint_idx in SMPLH_JOINT2NUM.items():
        # Skip joints beyond what the rotation data provides (e.g. body-only data).
        if smplh_joint_idx >= num_joints:
            continue

        # Get the FBX node name from mapping
        fbx_node_name = smplh_to_fbx_mapping.get(smplh_joint_name)
        if not fbx_node_name:
            if smplh_joint_idx == 0:
                print(f"Warning: Root joint 'Pelvis' not found in mapping!")
            continue

        # Find the node
        node = nodes_map.get(fbx_node_name)
        if not node:
            print(f"Warning: Joint '{smplh_joint_name}' (FBX: '{fbx_node_name}') not found in scene")
            continue

        # Animate rotation
        _animateRotationKeyFrames(
            animLayer=animLayer,
            node=node,
            rot_matrices=rot_matrices[:, smplh_joint_idx],
            frameDuration=frameDuration,
        )

        # Animate translation for root joint (Pelvis)
        if smplh_joint_idx == 0:
            root_node = node
            # Add initial offset to translations (like smplh2woodfbx.py does: Translates[0] + trans)
            # The translations input is relative displacement, we need to add the template's initial position
            if root_initial_translation is not None:
                final_translations = translations + root_initial_translation
                print(
                    f"Applying root translation to '{fbx_node_name}', frames={num_frames}, "
                    f"initial_offset={root_initial_translation}, "
                    f"final translation range: {final_translations.min(axis=0)} to {final_translations.max(axis=0)}"
                )
            else:
                final_translations = translations
                print(
                    f"Applying root translation to '{fbx_node_name}', frames={num_frames}, "
                    f"translation range: {final_translations.min(axis=0)} to {final_translations.max(axis=0)}"
                )
            _animateTranslationKeyFrames(
                animLayer=animLayer,
                node=node,
                translations=final_translations,
                frameDuration=frameDuration,
            )
            root_translation_applied = True

    # If root translation was not applied, try to find root node by common names
    if not root_translation_applied:
        print("Warning: Root translation was not applied through normal mapping, trying fallback...")
        root_candidates = ["Pelvis", "pelvis", "Hips", "hips", "Root", "root", "mixamorig:Hips"]
        for candidate in root_candidates:
            if candidate in nodes_map:
                root_node = nodes_map[candidate]
                # Get initial translation for fallback node
                initial_trans = root_node.LclTranslation.Get()
                fallback_initial = np.array([initial_trans[0], initial_trans[1], initial_trans[2]])
                final_translations = translations + fallback_initial
                print(
                    f"Found root node by fallback: '{candidate}', initial_offset={fallback_initial}, applying translation..."
                )
                _animateTranslationKeyFrames(
                    animLayer=animLayer,
                    node=root_node,
                    translations=final_translations,
                    frameDuration=frameDuration,
                )
                root_translation_applied = True
                break

    # Last resort: nothing matched — the exported take would have no root motion.
    if not root_translation_applied:
        print("ERROR: Could not find root node to apply translation!")
        print(f"Available nodes: {list(nodes_map.keys())}")

    return animStack
308
+
309
+
310
def _saveScene(filename, fbxManager, fbxScene, embed_textures=True):
    """Save the FBX scene to a file

    Args:
        filename: Output file path
        fbxManager: FBX manager instance
        fbxScene: FBX scene to save
        embed_textures: Whether to embed textures/media in the FBX file (default True)

    Raises:
        Exception: if the exporter cannot be initialized for *filename*.
    """
    # Configure IOSettings to embed textures/media
    ios = fbxManager.GetIOSettings()
    if embed_textures:
        ios.SetBoolProp(fbx.EXP_FBX_EMBEDDED, True)
        ios.SetBoolProp(fbx.EXP_FBX_MATERIAL, True)
        ios.SetBoolProp(fbx.EXP_FBX_TEXTURE, True)

    exporter = fbx.FbxExporter.Create(fbxManager, "")
    # -1 lets the SDK select the file format from the extension.
    if not exporter.Initialize(filename, -1, ios):
        raise Exception(f"Exporter failed to initialize. Error: {exporter.GetStatus().GetErrorString()}")

    # Ensure the exporter is destroyed even if Export() raises.
    try:
        exporter.Export(fbxScene)
    finally:
        exporter.Destroy()
334
+
335
+
336
def _convert_smplh_to_woodfbx(
    template_fbx_path,
    npz_data,
    save_fn,
    fps=30,
    scale=100,
    smplh_to_fbx_mapping=None,
    clear_animations=True,
):
    """
    Convert SMPL-H parameters to FBX using a template FBX file.
    The template FBX skeleton is already consistent with SMPL-H, so we directly copy parameters.

    Args:
        template_fbx_path: Path to the template FBX file (e.g. boy_Rigging_smplx.fbx)
        npz_data: Dictionary containing SMPL-H parameters
            - poses: (num_frames, 52, 3) or (num_frames, 156)
            - trans: (num_frames, 3)
        save_fn: Output FBX file path
        fps: Frame rate
        scale: Scale factor for translation (default 100 for m to cm conversion)
        smplh_to_fbx_mapping: Custom mapping from SMPL-H joint names to FBX node names
        clear_animations: Whether to clear existing animations in the template

    Returns:
        bool: True if successful
    """
    # Prepare poses data
    poses = npz_data["poses"]
    if isinstance(poses, np.ndarray):
        poses = torch.from_numpy(poses).float()

    if len(poses.shape) == 2:
        # (num_frames, 156) -> (num_frames, 52, 3)
        poses = poses.reshape(poses.shape[0], -1, 3)

    # Convert axis-angle to rotation matrices: (num_frames, num_joints, 3, 3)
    rot_matrices = angle_axis_to_rotation_matrix(poses).numpy()

    # Prepare translation data
    trans = npz_data["trans"]
    if isinstance(trans, torch.Tensor):
        trans = trans.numpy()

    # Apply scale to translation
    translations = trans * scale

    # Create FBX manager and load template
    fbxManager = fbx.FbxManager.Create()
    ios = fbx.FbxIOSettings.Create(fbxManager, fbx.IOSROOT)
    fbxManager.SetIOSettings(ios)

    print(f"Loading FBX template: {template_fbx_path}")
    fbxScene = _loadFbxScene(fbxManager, template_fbx_path)

    # Set time mode
    timeMode = fbx.FbxTime().ConvertFrameRateToTimeMode(fps)
    fbxScene.GetGlobalSettings().SetTimeMode(timeMode)

    # Collect all nodes
    rootNode = fbxScene.GetRootNode()
    all_nodes = _collectAllNodes(rootNode)
    skeleton_nodes = _collectSkeletonNodes(rootNode)

    print(f"Found {len(all_nodes)} nodes in scene")
    print(f"Found {len(skeleton_nodes)} skeleton nodes: {list(skeleton_nodes.keys())}")

    # Use default mapping if not provided
    if smplh_to_fbx_mapping is None:
        smplh_to_fbx_mapping = _auto_detect_mapping(all_nodes)
        print(f"Auto-detected {len(smplh_to_fbx_mapping)} joint mappings")
        if "Pelvis" in smplh_to_fbx_mapping:
            print(f"  Root joint 'Pelvis' mapped to: '{smplh_to_fbx_mapping['Pelvis']}'")
        else:
            print(f"  WARNING: Root joint 'Pelvis' not found in mapping!")
            print(f"  Available nodes: {list(all_nodes.keys())[:20]}...")  # Show first 20 nodes

    # Clear existing animations if requested
    if clear_animations:
        _clearExistingAnimations(fbxScene)

    # Apply animation to skeleton
    _applyAnimationToSkeleton(
        fbxScene=fbxScene,
        nodes_map=all_nodes,
        rot_matrices=rot_matrices,
        translations=translations,
        fps=fps,
        smplh_to_fbx_mapping=smplh_to_fbx_mapping,
        name="SMPLH_Animation",
    )

    # Save to temporary file first, then copy to final destination
    os.makedirs(os.path.dirname(save_fn) if os.path.dirname(save_fn) else ".", exist_ok=True)
    with tempfile.NamedTemporaryFile(suffix=".fbx", delete=False) as tmp_f:
        temp_file = tmp_f.name

    try:
        _saveScene(temp_file, fbxManager, fbxScene)
        shutil.copy2(temp_file, save_fn)
        print(f"Successfully saved FBX to: {save_fn}")
    except Exception as e:
        print(f"Error saving FBX file: {e}")
        return False
    finally:
        # Always remove the temp file (previously leaked when saving failed)
        # and always release the FBX SDK objects.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        fbxManager.Destroy()
        del fbxManager, fbxScene

    return os.path.exists(save_fn)
446
+
447
+
448
def _auto_detect_mapping(all_nodes):
    """Auto-detect the mapping from SMPL-H joints to FBX nodes."""
    mapping = {}
    for joint_name in SMPLH_JOINT2NUM:
        lowercase_name = SMPLH_TO_LOWERCASE_MAPPING.get(joint_name)
        if joint_name in all_nodes:
            # Exact name match between SMPL-H joint and FBX node.
            mapping[joint_name] = joint_name
        elif lowercase_name in all_nodes:
            # Fall back to the project's lowercase naming convention.
            mapping[joint_name] = lowercase_name
    return mapping
459
+
460
+
461
class SMPLH2WoodFBX:
    """
    Class to convert SMPL-H parameters to FBX using a template FBX file.
    The template FBX skeleton is already consistent with SMPL-H, so we directly copy parameters.
    No SMPL-H model assets (model.npz) required.

    Example usage:
        converter = SMPLH2WoodFBX(
            template_fbx_path="./assets/wooden_models/boy_Rigging_smplx.fbx"
        )

        # From npz file
        converter.convert_npz_to_fbx("motion.npz", "output.fbx", fps=30)

        # From parameters dict
        params = {
            "poses": poses_array,  # (num_frames, 52, 3) or (num_frames, 156)
            "trans": trans_array,  # (num_frames, 3)
        }
        converter.convert_params_to_fbx(params, "output.fbx")
    """

    def __init__(
        self,
        template_fbx_path: str = "./assets/wooden_models/boy_Rigging_smplx_tex.fbx",
        smplh_to_fbx_mapping: Optional[Dict[str, str]] = None,
        scale: float = 100,
    ):
        """
        Initialize the converter.

        Args:
            template_fbx_path: Path to the template FBX file
            smplh_to_fbx_mapping: Custom mapping from SMPL-H joint names to FBX node names
            scale: Scale factor for translation (default 100 for m to cm conversion)
        """
        print(f"[{self.__class__.__name__}] Template FBX: {template_fbx_path}")
        self.template_fbx_path = template_fbx_path
        self.smplh_to_fbx_mapping = smplh_to_fbx_mapping
        self.scale = scale

        # Analyze template FBX to detect joint names
        self._analyze_template()

    def _analyze_template(self):
        """Analyze the template FBX file to detect available skeleton nodes"""
        fbxManager = fbx.FbxManager.Create()
        ios = fbx.FbxIOSettings.Create(fbxManager, fbx.IOSROOT)
        fbxManager.SetIOSettings(ios)

        try:
            fbxScene = _loadFbxScene(fbxManager, self.template_fbx_path)
            rootNode = fbxScene.GetRootNode()

            self.all_template_nodes = list(_collectAllNodes(rootNode).keys())
            self.skeleton_template_nodes = list(_collectSkeletonNodes(rootNode).keys())

            print(f"[{self.__class__.__name__}] Template nodes: {len(self.all_template_nodes)}")
            print(f"[{self.__class__.__name__}] Skeleton nodes: {self.skeleton_template_nodes}")

            # Auto-detect mapping if not provided
            if self.smplh_to_fbx_mapping is None:
                self.smplh_to_fbx_mapping = self._auto_detect_mapping()
                print(f"[{self.__class__.__name__}] Auto-detected {len(self.smplh_to_fbx_mapping)} joint mappings")
        finally:
            fbxManager.Destroy()

    def _auto_detect_mapping(self):
        """Auto-detect the mapping from SMPL-H joints to FBX nodes.

        Delegates to the module-level ``_auto_detect_mapping`` helper so the
        lookup rules (exact match, then lowercase fallback) live in one place;
        the helper only performs ``in`` membership tests, so the node-name list
        works as its container argument.
        """
        return _auto_detect_mapping(self.all_template_nodes)

    def convert_npz_to_fbx(self, npz_file, outname, fps=30, clear_animations=True):
        """
        Convert an npz file containing SMPL-H parameters to FBX.

        Args:
            npz_file: Path to the npz file or dict containing SMPL-H parameters
            outname: Output FBX file path
            fps: Frame rate
            clear_animations: Whether to clear existing animations in template

        Returns:
            bool: True if successful
        """
        os.makedirs(os.path.dirname(outname) if os.path.dirname(outname) else ".", exist_ok=True)

        # Accept either a path on disk or an already-loaded parameter dict.
        if isinstance(npz_file, str) and os.path.isfile(npz_file):
            npz_data = dict(np.load(npz_file, allow_pickle=True))
        else:
            npz_data = npz_file

        return _convert_smplh_to_woodfbx(
            template_fbx_path=self.template_fbx_path,
            npz_data=npz_data,
            save_fn=outname,
            fps=fps,
            scale=self.scale,
            smplh_to_fbx_mapping=self.smplh_to_fbx_mapping,
            clear_animations=clear_animations,
        )

    def convert_params_to_fbx(self, params, outname, clear_animations=True):
        """
        Convert SMPL-H parameters to FBX.

        Args:
            params: Dictionary containing SMPL-H parameters
                - poses: (num_frames, 52, 3) or (num_frames, 156)
                - trans: (num_frames, 3)
                - mocap_framerate (optional): Frame rate
            outname: Output FBX file path
            clear_animations: Whether to clear existing animations in template

        Returns:
            bool: True if successful
        """
        fps = params.get("mocap_framerate", 30)
        os.makedirs(os.path.dirname(outname) if os.path.dirname(outname) else ".", exist_ok=True)

        npz_data = {
            "poses": params["poses"],
            "trans": params["trans"],
        }

        return _convert_smplh_to_woodfbx(
            template_fbx_path=self.template_fbx_path,
            npz_data=npz_data,
            save_fn=outname,
            fps=fps,
            scale=self.scale,
            smplh_to_fbx_mapping=self.smplh_to_fbx_mapping,
            clear_animations=clear_animations,
        )
602
+
603
+
604
# Mean hand poses in axis-angle form: 45 values per hand (15 hand joints x 3).
# The right-hand values mirror the left (sign flips on the Y/Z components).
# Used by construct_smpl_data_dict to fill in hand articulation when only the
# 22 body joints are predicted, or when the default hand pose is requested.
LEFT_HAND_MEAN_AA = [ 0.1117,  0.0429, -0.4164,  0.1088, -0.0660, -0.7562, -0.0964, -0.0909,
                     -0.1885, -0.1181,  0.0509, -0.5296, -0.1437,  0.0552, -0.7049, -0.0192,
                     -0.0923, -0.3379, -0.4570, -0.1963, -0.6255, -0.2147, -0.0660, -0.5069,
                     -0.3697, -0.0603, -0.0795, -0.1419, -0.0859, -0.6355, -0.3033, -0.0579,
                     -0.6314, -0.1761, -0.1321, -0.3734,  0.8510,  0.2769, -0.0915, -0.4998,
                      0.0266,  0.0529,  0.5356,  0.0460, -0.2774]
RIGHT_HAND_MEAN_AA = [ 0.1117, -0.0429,  0.4164,  0.1088,  0.0660,  0.7562, -0.0964,  0.0909,
                       0.1885, -0.1181, -0.0509,  0.5296, -0.1437, -0.0552,  0.7049, -0.0192,
                       0.0923,  0.3379, -0.4570,  0.1963,  0.6255, -0.2147,  0.0660,  0.5069,
                      -0.3697,  0.0603,  0.0795, -0.1419,  0.0859,  0.6355, -0.3033,  0.0579,
                       0.6314, -0.1761,  0.1321,  0.3734,  0.8510, -0.2769,  0.0915, -0.4998,
                      -0.0266, -0.0529,  0.5356, -0.0460,  0.2774]
616
+
617
def construct_smpl_data_dict(
    rot6d,
    transl,
    betas=None,
    gender="neutral",
    use_default_hand_mean_pose=False,
) -> dict:
    """Build an AMASS-style SMPL-H parameter dict from 6D rotations and translations.

    Args:
        rot6d: per-frame 6D joint rotations; converted internally to axis-angle with
            22 (body only) or 52 (body + hands) joints per frame.
        transl: (num_frames, 3) root translations (torch tensor).
        betas: Optional shape coefficients tensor; defaults to zeros of shape (1, 16).
        gender: Body-model gender tag stored verbatim in the output dict.
        use_default_hand_mean_pose: When 52 joints are supplied, replace the predicted
            hand poses with the canonical mean hand pose.

    Returns:
        dict with keys: betas, gender, poses, trans, mocap_framerate, num_frames, Rh.
    """
    rotation_matrix = rot6d_to_rotation_matrix(rot6d)
    angle_axis = rotation_matrix_to_angle_axis(rotation_matrix)

    def _mean_hand_pose(flat_aa):
        # (45,) axis-angle constants -> (num_frames, 15, 3), repeated per frame.
        return (
            torch.tensor(flat_aa, device=angle_axis.device, dtype=angle_axis.dtype)
            .unsqueeze(0)
            .repeat(angle_axis.shape[0], 1)
            .reshape(angle_axis.shape[0], -1, 3)
        )

    left_hand_mean_pose = _mean_hand_pose(LEFT_HAND_MEAN_AA)
    right_hand_mean_pose = _mean_hand_pose(RIGHT_HAND_MEAN_AA)

    if angle_axis.shape[1] == 22:
        # Body-only prediction: append the mean pose for both hands.
        angle_axis = torch.cat([angle_axis, left_hand_mean_pose, right_hand_mean_pose], dim=1)
    elif angle_axis.shape[1] == 52 and use_default_hand_mean_pose:
        # Full-body prediction, but override the predicted hands with the mean pose.
        angle_axis = torch.cat([angle_axis[:, :22], left_hand_mean_pose, right_hand_mean_pose], dim=1)

    assert angle_axis.shape[1] == 52, f"angle_axis should be 52, but got {angle_axis.shape[1]}"
    flat_poses = angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1)
    dump = {
        "betas": betas.cpu().numpy() if betas is not None else np.zeros((1, 16)),
        "gender": gender,
        "poses": flat_poses,
        "trans": transl.cpu().numpy(),
        "mocap_framerate": 30,
        "num_frames": angle_axis.shape[0],
        # Global (root) orientation: first three axis-angle components per frame.
        "Rh": flat_poses[:, :3],
    }
    return dump
679
+
680
if __name__ == "__main__":
    # python hymotion/utils/smplh2woodfbx.py
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("root", type=str)
    args = parser.parse_args()

    converter = SMPLH2WoodFBX(
        template_fbx_path="./assets/wooden_models/boy_Rigging_smplx_tex.fbx",
        scale=100,
    )

    # Accept either a directory of .npz motions or one .npz file.
    if os.path.isdir(args.root):
        npz_paths = sorted(glob.glob(os.path.join(args.root, "*.npz")))
    elif args.root.endswith(".npz"):
        npz_paths = [args.root]
    else:
        raise ValueError(f"Unknown file type: {args.root}")

    for npz_path in npz_paths:
        out_path = npz_path.replace(".npz", ".fbx").replace("motions", "motions_fbx")
        converter.convert_npz_to_fbx(npz_path, out_path)
hymotion/utils/t2m_runtime.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # t2m_runtime.py
2
+ import os
3
+ import threading
4
+ import time
5
+ import uuid
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import yaml
10
+
11
+ from ..prompt_engineering.prompt_rewrite import PromptRewriter
12
+ from .loaders import load_object
13
+ from .visualize_mesh_web import save_visualization_data
14
+
15
+ try:
16
+ import fbx
17
+
18
+ FBX_AVAILABLE = True
19
+ print(">>> FBX module found.")
20
+ except ImportError:
21
+ FBX_AVAILABLE = False
22
+ print(">>> FBX module not found.")
23
+
24
+
25
+ def _get_local_ip():
26
+ import subprocess
27
+
28
+ result = subprocess.run(["hostname", "-I"], capture_output=True, text=True, timeout=5)
29
+ if result.returncode == 0:
30
+ for ip in result.stdout.strip().split():
31
+ if not ip.startswith("127.") and not ip.startswith("172.17."):
32
+ return ip
33
+ return "localhost"
34
+
35
+
36
+ def _now():
37
+ t = time.time()
38
+ ms = int((t - int(t)) * 1000)
39
+ return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}"
40
+
41
+
42
+ class T2MRuntime:
43
    def __init__(
        self,
        config_path: str,
        ckpt_name: str = "latest.ckpt",
        skip_text: bool = False,
        device_ids: Union[list[int], None] = None,
        prompt_engineering_host: Optional[str] = None,
        skip_model_loading: bool = False,
    ):
        """Set up the text-to-motion runtime: device selection, pipelines, rewriter, FBX export.

        Args:
            config_path: Path to the YAML config describing pipeline/network modules.
            ckpt_name: Checkpoint file path/name passed to the pipeline loader.
            skip_text: If True, pipelines are built without a text encoder (debug mode).
            device_ids: CUDA device ids to use; None means all visible GPUs.
            prompt_engineering_host: Optional host for the prompt rewriting backend.
            skip_model_loading: If True, skip loading weights entirely (used when the
                checkpoint is missing); generation requests will raise later.
        """
        self.config_path = config_path
        self.ckpt_name = ckpt_name
        self.skip_text = skip_text
        self.prompt_engineering_host = prompt_engineering_host
        self.skip_model_loading = skip_model_loading
        self.local_ip = _get_local_ip()

        # Check for CPU-only mode via environment variable
        # Set HY_MOTION_DEVICE=cpu to force CPU mode
        force_cpu = os.environ.get("HY_MOTION_DEVICE", "").lower() == "cpu"
        if force_cpu:
            print(">>> [INFO] CPU mode enabled via HY_MOTION_DEVICE=cpu environment variable")
            self.device_ids = []
        elif torch.cuda.is_available():
            all_ids = list(range(torch.cuda.device_count()))
            # Keep only requested ids that actually exist on this machine.
            self.device_ids = all_ids if device_ids is None else [i for i in device_ids if i in all_ids]
        else:
            self.device_ids = []

        # One pipeline per device; _gpu_load[i] == 1 marks pipeline i as busy.
        self.pipelines = []
        self._gpu_load = []
        self._lock = threading.Lock()
        self._loaded = False

        self.prompt_rewriter = PromptRewriter(backend="our_rewriter", host=self.prompt_engineering_host)

        # Skip model loading if checkpoint not found
        if self.skip_model_loading:
            print(">>> [WARNING] Model loading skipped - checkpoint not found")
            self._loaded = True  # Mark as loaded to prevent further load attempts
        else:
            self.load()
        self.fbx_available = FBX_AVAILABLE
        if self.fbx_available:
            try:
                from .smplh2woodfbx import SMPLH2WoodFBX

                # Converter construction reads the template FBX, which may fail;
                # fall back to HTML-only export in that case.
                self.fbx_converter = SMPLH2WoodFBX()
            except Exception as e:
                print(f">>> Failed to initialize FBX converter: {e}")
                self.fbx_available = False
                self.fbx_converter = None
        else:
            self.fbx_converter = None
            print(">>> FBX module not found. FBX export will be disabled.")

        device_info = self.device_ids if self.device_ids else 'cpu'
        if self.skip_model_loading:
            print(f">>> T2MRuntime initialized (model NOT loaded) in IP {self.local_ip}, devices={device_info}")
        else:
            print(f">>> T2MRuntime loaded in IP {self.local_ip}, devices={device_info}")
103
+
104
    def load(self):
        """Build and load the diffusion pipeline(s) from the YAML config, one per device.

        Idempotent: returns immediately once pipelines have been loaded.
        """
        if self._loaded:
            return
        print(f">>> Loading model from {self.config_path}...")

        with open(self.config_path, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        if not self.device_ids:
            # No GPUs selected/available: a single pipeline on CPU.
            pipeline = load_object(
                config["train_pipeline"],
                config["train_pipeline_args"],
                network_module=config["network_module"],
                network_module_args=config["network_module_args"],
            )
            device = torch.device("cpu")
            pipeline.load_in_demo(
                self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text
            )
            pipeline.to(device)
            self.pipelines = [pipeline]
            self._gpu_load = [0]
        else:
            # One independent pipeline per selected GPU, all initially idle.
            for gid in self.device_ids:
                p = load_object(
                    config["train_pipeline"],
                    config["train_pipeline_args"],
                    network_module=config["network_module"],
                    network_module_args=config["network_module_args"],
                )
                p.load_in_demo(self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text)
                p.to(torch.device(f"cuda:{gid}"))
                self.pipelines.append(p)
            self._gpu_load = [0] * len(self.pipelines)

        self._loaded = True
140
+
141
+ def _acquire_pipeline(self) -> int:
142
+ while True:
143
+ with self._lock:
144
+ for i in range(len(self._gpu_load)):
145
+ if self._gpu_load[i] == 0:
146
+ self._gpu_load[i] = 1
147
+ return i
148
+ time.sleep(0.01)
149
+
150
+ def _release_pipeline(self, idx: int):
151
+ with self._lock:
152
+ self._gpu_load[idx] = 0
153
+
154
    def test_dit_inference(self, duration: float = 2.0, seed: int = 42) -> bool:
        """
        Test DiT model inference with unconditional/blank input.
        This method is used to verify the DiT model works before loading text encoder.

        Args:
            duration: Duration of the test motion in seconds
            seed: Random seed for reproducibility

        Returns:
            True if inference succeeds and produces valid output

        Raises:
            RuntimeError: if no pipeline has been loaded yet.
            Exception: re-raises whatever the forward pass failed with.
        """
        if not self.pipelines:
            raise RuntimeError("No pipeline loaded. Call load() first.")

        # Reserve one pipeline exclusively; released in the finally block below.
        pi = self._acquire_pipeline()
        try:
            pipeline = self.pipelines[pi]
            pipeline.eval()
            device = next(pipeline.parameters()).device

            # Calculate frame length from duration (assuming 30fps output, 20fps internal)
            length = int(duration * 20)
            length = min(length, pipeline.train_frames)

            # Use null features for unconditional generation
            batch_size = 1
            vtxt_input = pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device)
            ctxt_input = pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device)
            ctxt_length = torch.tensor([1] * batch_size, device=device)

            # Create masks
            from ..pipeline.motion_diffusion import length_to_mask

            ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1])
            x_length = torch.LongTensor([length] * batch_size).to(device)
            x_mask_temporal = length_to_mask(x_length, pipeline.train_frames)

            # Run denoising inference
            print(f"\t>>> Running DiT inference test: length={length}, device={device}")

            # Create random noise
            generator = torch.Generator(device=device).manual_seed(seed)
            latent_shape = (batch_size, pipeline.train_frames, pipeline.mean.shape[-1])
            latents = torch.randn(latent_shape, generator=generator, device=device, dtype=vtxt_input.dtype)

            # Simple single-step denoising test (just forward pass)
            with torch.no_grad():
                # Get timestep
                timesteps = torch.tensor([0.5], device=device, dtype=vtxt_input.dtype).expand(batch_size)

                # Forward pass through DiT
                # Use correct parameter names for HunyuanMotionMMDiT.forward()
                _ = pipeline.motion_transformer(
                    x=latents,
                    ctxt_input=ctxt_input,
                    vtxt_input=vtxt_input,
                    timesteps=timesteps,
                    x_mask_temporal=x_mask_temporal,
                    ctxt_mask_temporal=ctxt_mask_temporal,
                )

            print(f"\t>>> DiT forward pass completed successfully!")
            return True

        except Exception as e:
            print(f"\t>>> DiT inference test failed: {e}")
            raise
        finally:
            self._release_pipeline(pi)
224
+
225
+ def load_text_encoder(self) -> None:
226
+ """
227
+ Load text encoder for all pipelines.
228
+ This is called after DiT model testing to complete the initialization.
229
+ """
230
+ if not self.pipelines:
231
+ raise RuntimeError("No pipeline loaded. Call load() first.")
232
+
233
+ print(">>> Loading text encoder for all pipelines...")
234
+ for i, pipeline in enumerate(self.pipelines):
235
+ if not hasattr(pipeline, "text_encoder") or pipeline.text_encoder is None:
236
+ device = next(pipeline.parameters()).device
237
+ pipeline.text_encoder = load_object(pipeline._text_encoder_module, pipeline._text_encoder_cfg)
238
+ pipeline.text_encoder.to(device)
239
+ print(f"\t>>> Text encoder loaded for pipeline {i} on {device}")
240
+
241
+ # Update skip_text flag
242
+ self.skip_text = False
243
+ print(">>> Text encoder loading completed!")
244
+
245
+ def rewrite_text_and_infer_time(self, text: str) -> Tuple[float, str]:
246
+ print("Start rewriting text...")
247
+ duration, rewritten_text = self.prompt_rewriter.rewrite_prompt_and_infer_time(f"{text}")
248
+ print(f"\t>>> Rewritten text: {rewritten_text}, duration: {duration:.2f} seconds")
249
+ return duration, rewritten_text
250
+
251
+ def generate_motion(
252
+ self,
253
+ text: str,
254
+ seeds_csv: str,
255
+ duration: float,
256
+ cfg_scale: float,
257
+ output_format: str = "fbx",
258
+ output_dir: Optional[str] = None,
259
+ output_filename: Optional[str] = None,
260
+ original_text: Optional[str] = None,
261
+ use_special_game_feat: bool = False,
262
+ ) -> Tuple[Union[str, list[str]], dict]:
263
+ # Check if model was skipped due to missing checkpoint
264
+ if self.skip_model_loading:
265
+ raise RuntimeError(
266
+ "Motion generation is not available: model checkpoint was not found. "
267
+ "Please ensure the checkpoint file exists at the specified path."
268
+ )
269
+
270
+ self.load()
271
+ seeds = [int(s.strip()) for s in seeds_csv.split(",") if s.strip() != ""]
272
+ pi = self._acquire_pipeline()
273
+ try:
274
+ pipeline = self.pipelines[pi]
275
+ pipeline.eval()
276
+
277
+ # When skip_text=True (debug mode), use blank text features
278
+ if self.skip_text:
279
+ print(">>> [Debug Mode] Using blank text features (skip_text=True)")
280
+ device = next(pipeline.parameters()).device
281
+ batch_size = len(seeds) if seeds else 1
282
+ # Create blank hidden_state_dict using null features
283
+ hidden_state_dict = {
284
+ "text_vec_raw": pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device),
285
+ "text_ctxt_raw": pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device),
286
+ "text_ctxt_raw_length": torch.tensor([1] * batch_size, device=device),
287
+ }
288
+ # Disable CFG in debug mode (use cfg_scale=1.0)
289
+ model_output = pipeline.generate(
290
+ text,
291
+ seeds,
292
+ duration,
293
+ cfg_scale=1.0,
294
+ use_special_game_feat=False,
295
+ hidden_state_dict=hidden_state_dict,
296
+ )
297
+ else:
298
+ model_output = pipeline.generate(
299
+ text, seeds, duration, cfg_scale=cfg_scale, use_special_game_feat=use_special_game_feat
300
+ )
301
+ finally:
302
+ self._release_pipeline(pi)
303
+
304
+ ts = _now()
305
+ save_data, base_filename = save_visualization_data(
306
+ output=model_output,
307
+ text=text if original_text is None else original_text,
308
+ rewritten_text=text,
309
+ timestamp=ts,
310
+ output_dir=output_dir,
311
+ output_filename=output_filename,
312
+ )
313
+
314
+ view_url = self._generate_html_view_url(
315
+ timestamp=ts,
316
+ file_path=base_filename,
317
+ output_dir=output_dir,
318
+ )
319
+
320
+ if output_format == "fbx" and not self.fbx_available:
321
+ print(">>> Warning: FBX export requested but FBX SDK is not available. Falling back to html.")
322
+ output_format = "html"
323
+
324
+ if output_format == "fbx" and self.fbx_available:
325
+ fbx_files = self._generate_fbx_files(
326
+ visualization_data=save_data,
327
+ output_dir=output_dir,
328
+ fbx_filename=output_filename,
329
+ )
330
+ return view_url, fbx_files, model_output
331
+ else:
332
+ raise ValueError(f">>> Invalid output format: {output_format}")
333
+
334
+ def _generate_html_view_url(
335
+ self,
336
+ timestamp: str,
337
+ file_path: str,
338
+ output_dir: Optional[str] = None,
339
+ ) -> str:
340
+ print(f">>> HTML ready, timestamp: {timestamp}")
341
+ gradio_dir = output_dir if output_dir is not None else "output/gradio"
342
+ view_url = f"/view/{gradio_dir}/{file_path}"
343
+ return view_url
344
+
345
    def _generate_fbx_files(
        self,
        visualization_data: dict,
        output_dir: Optional[str] = None,
        fbx_filename: Optional[str] = None,
    ) -> List[str]:
        """Convert each SMPL sample in ``visualization_data`` to an FBX file.

        Args:
            visualization_data: Dict produced by ``save_visualization_data``;
                must contain "smpl_data" (list of per-sample SMPL dicts),
                "text" and "timestamp".
            output_dir: Target directory; defaults to <repo_root>/output/gradio.
            fbx_filename: Base name for the FBX files; when omitted, the
                timestamp plus a random 8-char id is used.

        Returns:
            List of written file paths. NOTE: for every successful FBX a
            sidecar ``.txt`` holding the prompt text is also written and
            appended to this list, so callers get both paths per sample.
        """
        assert "smpl_data" in visualization_data, "smpl_data not found in visualization_data"
        fbx_files = []
        if output_dir is None:
            # Default to <repo_root>/output/gradio (three levels above this file).
            root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
            output_dir = os.path.join(root_dir, "output", "gradio")

        smpl_data_list = visualization_data["smpl_data"]

        # Short random id avoids collisions when requests share a timestamp.
        unique_id = str(uuid.uuid4())[:8]
        text = visualization_data["text"]
        timestamp = visualization_data["timestamp"]
        for bb in range(len(smpl_data_list)):
            smpl_data = smpl_data_list[bb]
            if fbx_filename is None:
                fbx_filename_bb = f"{timestamp}_{unique_id}_{bb:03d}.fbx"
            else:
                fbx_filename_bb = f"{fbx_filename}_{bb:03d}.fbx"
            fbx_path = os.path.join(output_dir, fbx_filename_bb)
            success = self.fbx_converter.convert_npz_to_fbx(smpl_data, fbx_path)
            if success:
                fbx_files.append(fbx_path)
                print(f"\t>>> FBX file generated: {fbx_path}")
                # Write the prompt next to the FBX so downstream tools keep the caption.
                txt_path = fbx_path.replace(".fbx", ".txt")
                with open(txt_path, "w", encoding="utf-8") as f:
                    f.write(text)
                fbx_files.append(txt_path)

        return fbx_files
hymotion/utils/type_converter.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
def get_module_device(module: nn.Module) -> torch.device:
    """Get the device of a module.

    Args:
        module (nn.Module): A module that contains parameters.

    Returns:
        torch.device: The device of the module's first parameter
            (``cuda:<idx>`` for CUDA parameters, otherwise ``cpu``).

    Raises:
        ValueError: If the module has no parameters.
    """
    try:
        # Fetch the first parameter once instead of re-iterating three times.
        first_param = next(module.parameters())
    except StopIteration:
        # ``from None`` keeps the StopIteration out of the traceback chain.
        raise ValueError("The input module should contain parameters.") from None

    if first_param.is_cuda:
        return torch.device(first_param.get_device())

    return torch.device("cpu")
hymotion/utils/visualize_mesh_web.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ import threading
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ _FILE_ACCESS_LOCK = threading.Lock()
12
+
13
+
14
def sanitize_filename(filename: str) -> str:
    """
    Sanitize filename to prevent path traversal attacks.

    Args:
        filename: original filename
    Returns:
        sanitized filename (only [a-zA-Z0-9_.-], no leading dot, <= 255 chars)
    """
    if not filename:
        return ""

    # remove all path traversal characters
    # BUG FIX: the previous pattern used r"\\\\\\" for the backslash
    # alternative, which matches THREE literal backslashes, so "..\" was never
    # consumed as a unit; r"\\" matches the single backslash as intended.
    filename = re.sub(r"\.\.(/|\\)?", "", filename)
    filename = filename.strip("./\\")

    # only allow letters, numbers, underscores, hyphens and dots
    filename = re.sub(r"[^a-zA-Z0-9_.-]", "", filename)

    # collapse any remaining consecutive dots
    while ".." in filename:
        filename = filename.replace("..", ".")

    # prevent starting with a dot (hidden file)
    if filename.startswith("."):
        filename = filename[1:]

    # limit file name length (common filesystem maximum)
    if len(filename) > 255:
        filename = filename[:255]

    return filename
46
+
47
+
48
def sanitize_folder_name(folder_name: str) -> str:
    """
    Sanitize a folder path so it cannot escape the output root.

    Traversal sequences are stripped, each path segment is reduced to
    [a-zA-Z0-9_-], and depth is capped at three levels. Falls back to the
    default "output" folder when nothing usable remains.

    Args:
        folder_name: original folder name
    Returns:
        sanitized folder name
    """
    if not folder_name:
        return "output"  # default folder

    # Strip traversal sequences and surrounding separator/dot characters.
    stripped = re.sub(r"\.\.(/|\\\\\\)?", "", folder_name).strip("./\\")

    # Keep only characters that can appear in a (possibly nested) folder path.
    stripped = re.sub(r"[^a-zA-Z0-9_./-]", "", stripped)

    # Clean every segment individually so slashes cannot smuggle traversal.
    segments = []
    for raw in stripped.split("/"):
        if raw in ("", ".", ".."):
            continue
        cleaned = re.sub(r"[^a-zA-Z0-9_-]", "", raw)
        if cleaned:
            segments.append(cleaned)

    # Allow at most 3 levels of directory depth.
    segments = segments[:3]

    return "/".join(segments) if segments else "output"
82
+
83
+
84
def safe_path_join(base_dir: str, *paths: str) -> str:
    """
    Safely join path components, guaranteeing the result stays inside base_dir.

    Each component is stripped of traversal sequences and reduced to
    [a-zA-Z0-9_.-] before joining; the resolved path is then verified to be
    contained in base_dir.

    Args:
        base_dir: base directory
        *paths: paths to join
    Returns:
        joined, fully resolved path
    Raises:
        ValueError: if path traversal is detected
    """
    # Sanitize every component before joining.
    sanitized = []
    for raw in paths:
        if not raw:
            continue
        part = re.sub(r"\.\.(/|\\\\\\)?", "", raw)
        part = part.strip("./\\")
        part = re.sub(r"[^a-zA-Z0-9_.-]", "", part)
        if part:
            sanitized.append(part)

    candidate = os.path.join(base_dir, *sanitized)

    # Resolve symlinks on both sides, then verify containment.
    base_dir = os.path.realpath(base_dir)
    candidate = os.path.realpath(os.path.normpath(candidate))

    if os.path.commonpath([base_dir, candidate]) != base_dir:
        raise ValueError(f"Path traversal detected: {candidate} is outside {base_dir}")

    return candidate
117
+
118
+
119
+ def _get_root_dir() -> str:
120
+ return os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
121
+
122
+
123
def get_output_dir(sub_path: str = "") -> str:
    """Return (and create) the project output base dir, optionally extended by sub_path.

    sub_path is normalized to forward slashes and fed through safe_path_join
    segment by segment, so the result can never escape the base directory.
    """
    base = _get_root_dir()
    os.makedirs(base, exist_ok=True)
    if sub_path:
        segments = [seg for seg in sub_path.replace("\\", "/").split("/") if seg]
    else:
        segments = []
    return safe_path_join(base, *segments)
132
+
133
+
134
def save_visualization_data(
    output: Dict[str, Union[Tensor, list[str]]],
    text: str,
    rewritten_text: Union[str, list[str]],
    timestamp: str,
    output_dir: Optional[str] = None,
    output_filename: Optional[str] = None,
):
    """Persist generated motion to disk as per-sample NPZ files plus metadata.

    Writes ``<base>_meta.json`` and one ``<base>_<idx:03d>.npz`` per batch
    sample into ``output_dir`` (default: <root>/output/gradio), and returns an
    in-memory dict holding the per-sample SMPL data for the FBX exporter.

    Args:
        output: Model output; must contain "rot6d" and "transl" tensors whose
            leading dimension is the sample/batch index.
        text: Original prompt, stored in the metadata.
        rewritten_text: Rewritten prompt(s); a bare string is wrapped in a list.
        timestamp: Used as the default base filename.
        output_dir: Target directory (created if missing).
        output_filename: Base filename overriding the timestamp.

    Returns:
        Tuple of (memory_data dict, base_filename str).
    """
    # Imported lazily to avoid a circular import with the pipeline package.
    from ..pipeline.body_model import construct_smpl_data_dict

    if output_dir is None:
        output_dir = get_output_dir(sub_path="output/gradio")
    os.makedirs(output_dir, exist_ok=True)

    # for metadata
    base_filename = output_filename if output_filename else timestamp
    meta_path = safe_path_join(output_dir, f"{base_filename}_meta.json")
    if isinstance(rewritten_text, str):
        rewritten_text = [rewritten_text]
    batch_size = output["rot6d"].shape[0]
    meta_data = {
        "timestamp": timestamp,
        "text": text,
        "text_rewrite": rewritten_text,
        "num_samples": batch_size,
        "base_filename": base_filename,
    }

    # Serialize writes: concurrent gradio requests may target the same folder.
    with _FILE_ACCESS_LOCK:
        with open(meta_path, "w") as f:
            json.dump(meta_data, f, indent=2)

    # for smpl data
    rot6d = output["rot6d"]
    transl = output["transl"]

    all_smpl_data = []  # for FBX generator

    for bb in range(batch_size):
        # build data
        smpl_data = construct_smpl_data_dict(rot6d[bb].clone(), transl[bb].clone())
        all_smpl_data.append(smpl_data)

        # prepare dictionary to save into NPZ
        npz_dict = {}
        npz_dict["gender"] = np.array([smpl_data.get("gender", "neutral")], dtype=str)

        for key in ["Rh", "trans", "poses", "betas"]:
            if key in smpl_data:
                val = smpl_data[key]
                # Normalize python sequences / torch tensors to numpy for savez.
                if isinstance(val, (list, tuple)):
                    val = np.array(val)
                elif isinstance(val, torch.Tensor):
                    val = val.cpu().numpy()
                npz_dict[key] = val

        # save single NPZ
        sample_filename = f"{base_filename}_{bb:03d}.npz"
        sample_path = safe_path_join(output_dir, sample_filename)

        with _FILE_ACCESS_LOCK:
            np.savez_compressed(sample_path, **npz_dict)

    # construct memory dictionary to return (for compatibility)
    memory_data = {
        "timestamp": timestamp,
        "text": text,
        "text_rewrite": rewritten_text,
        "smpl_data": all_smpl_data,
        "meta_data": [],
    }

    # return base filename, subsequent logic will use this as a basis for finding _meta.json or _000.npz
    return memory_data, base_filename
208
+
209
+
210
def get_cached_captions(folder_name: str, file_name: str) -> List[dict]:
    """Read ``<file_name>_meta.json`` and build caption entries for the viewer.

    Both inputs are sanitized before touching the filesystem. If no metadata
    file matches ``file_name`` directly, the part before the last underscore
    is tried as well (per-sample files are named ``<base>_<idx>``).

    Returns:
        One dict per rewritten prompt with keys "short caption+" (rewritten
        text), optional "short caption" (original text when it differs), and
        null start/end times. Empty list on any failure.
    """
    folder_name = sanitize_folder_name(folder_name)
    file_name = sanitize_filename(file_name)

    base_dir = get_output_dir(folder_name)
    # Preferred location: exact base name.
    meta_path = safe_path_join(base_dir, f"{file_name}_meta.json")

    if not os.path.exists(meta_path):
        if "_" in file_name:
            # file_name may be a per-sample name like "<base>_000"; retry with base.
            prefix = file_name.rsplit("_", 1)[0]
            prefix = sanitize_filename(prefix)
            meta_path_alt = safe_path_join(base_dir, f"{prefix}_meta.json")
            if os.path.exists(meta_path_alt):
                meta_path = meta_path_alt
            else:
                return []
        else:
            return []

    try:
        with _FILE_ACCESS_LOCK:
            with open(meta_path, "r") as f:
                data = json.load(f)

        text = data.get("text", "")
        text_rewrite = data.get("text_rewrite", [])

        captions = []
        for i, t in enumerate(text_rewrite):
            item = {"short caption+": f"{t}", "start_time": None, "end_time": None}
            # Only attach the original prompt when it differs from the rewrite.
            if text and text != t:
                item["short caption"] = text
            captions.append(item)
        return captions
    except Exception as e:
        # Best-effort cache read: log and fall back to "no captions".
        print(f"Error reading meta json: {e}")
        return []
250
+
251
+
252
def get_cached_smpl_frames(folder_name: str, file_name: str) -> List[list]:
    """
    Load cached SMPL frames for the web viewer.

    Resolution logic:
    1. if ``file_name`` names an existing ``.npz`` of the form
       ``<base>_<idx>``, only that single sample is loaded;
    2. otherwise, if ``<file_name>_meta.json`` exists, all ``num_samples``
       samples listed there are loaded.

    Returns:
        A list of frames; each frame is a list of per-person dicts with keys
        "id", "gender", "Rh", "Th", "poses", "shapes". Empty list when nothing
        matches or loading fails.
    """
    folder_name = sanitize_folder_name(folder_name)
    file_name = sanitize_filename(file_name)

    base_dir = get_output_dir(folder_name)

    npz_direct_path = safe_path_join(base_dir, f"{file_name}.npz")
    meta_path = safe_path_join(base_dir, f"{file_name}_meta.json")

    target_indices = []
    base_name = file_name

    if os.path.isfile(npz_direct_path):
        try:
            if "_" in file_name:
                prefix, suffix = file_name.rsplit("_", 1)
                if suffix.isdigit():
                    # Single-sample request "<base>_<idx>".
                    num_samples = 1  # NOTE(review): assigned but never read
                    base_name = prefix
                    target_indices = [int(suffix)]
                else:
                    pass
            else:
                pass
        except ValueError:
            pass
        if not target_indices:
            # A direct .npz exists but its name does not follow "<base>_<idx>".
            return []
    elif os.path.exists(meta_path):
        try:
            with open(meta_path, "r") as f:
                meta = json.load(f)
            num_samples = meta.get("num_samples", 0)
            target_indices = range(num_samples)
        except Exception as e:
            print(f"Error reading meta: {e}")
            return []
    else:
        return []

    all_people = []

    for i in target_indices:
        npz_path = safe_path_join(base_dir, f"{base_name}_{i:03d}.npz")
        if not os.path.exists(npz_path):
            # Missing samples are skipped rather than aborting the whole load.
            continue

        try:
            with _FILE_ACCESS_LOCK:
                with np.load(npz_path, allow_pickle=False) as data:
                    # read single person data
                    gender = str(data["gender"][0])
                    Rh = data["Rh"]
                    Th = data["trans"]
                    poses = data["poses"]
                    betas = data["betas"]

            # Flatten per-joint pose axes into one vector per frame.
            # Assumes poses is (frames, joints, 3) when 3-D — TODO confirm.
            if poses.ndim == 3:
                poses = poses.reshape(poses.shape[0], -1)

            person_frames = []
            for f in range(len(poses)):
                frame = {
                    "id": i,
                    "gender": gender,
                    "Rh": Rh[f : f + 1].tolist(),
                    "Th": Th[f : f + 1].tolist(),
                    "poses": poses[f : f + 1].tolist(),
                    "shapes": betas.tolist(),
                }
                # Each entry is a one-element list so merge() can extend below.
                person_frames.append([frame])
            all_people.append(person_frames)
        except Exception as e:
            print(f"Error loading {npz_path}: {e}")

    # merge: frame f of the output holds every person's frame f
    combined_frames = []
    max_frames = max(len(p) for p in all_people) if all_people else 0
    for f_idx in range(max_frames):
        frame_content = []
        for person_seq in all_people:
            if f_idx < len(person_seq):
                frame_content.extend(person_seq[f_idx])
        combined_frames.append(frame_content)

    return combined_frames
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp ADDED

Git LFS Details

  • SHA256: aaed26cb89635f9d995c9c373919185c412f00f6b4a903873ac75ccd7a549439
  • Pointer size: 131 Bytes
  • Size of remote file: 228 kB
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp ADDED

Git LFS Details

  • SHA256: fe4bd8b80aadf6e414c3c07acc57532cda754deb93f0cdc6f211387b8beca97d
  • Pointer size: 130 Bytes
  • Size of remote file: 30.4 kB
scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp ADDED

Git LFS Details

  • SHA256: 9ad8bd31dfda3a8356ac5c9d6bdc7457f0197ae23754f9ac196f86d872b26781
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
scripts/gradio/static/assets/dump_wooden/faces.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:777b0806d2843c797ed18644eecc11466ff822b33b02d263c22f8ad3730e9bb5
3
+ size 290376
scripts/gradio/static/assets/dump_wooden/j_template.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f488cc9d0b650a816f5a8eda49eda7bc796f490e9130a6e0dec5be137d7b929
3
+ size 624
scripts/gradio/static/assets/dump_wooden/joint_names.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "Pelvis",
3
+ "L_Hip",
4
+ "R_Hip",
5
+ "Spine1",
6
+ "L_Knee",
7
+ "R_Knee",
8
+ "Spine2",
9
+ "L_Ankle",
10
+ "R_Ankle",
11
+ "Spine3",
12
+ "L_Foot",
13
+ "R_Foot",
14
+ "Neck",
15
+ "L_Collar",
16
+ "R_Collar",
17
+ "Head",
18
+ "L_Shoulder",
19
+ "R_Shoulder",
20
+ "L_Elbow",
21
+ "R_Elbow",
22
+ "L_Wrist",
23
+ "R_Wrist",
24
+ "L_Index1",
25
+ "L_Index2",
26
+ "L_Index3",
27
+ "L_Middle1",
28
+ "L_Middle2",
29
+ "L_Middle3",
30
+ "L_Pinky1",
31
+ "L_Pinky2",
32
+ "L_Pinky3",
33
+ "L_Ring1",
34
+ "L_Ring2",
35
+ "L_Ring3",
36
+ "L_Thumb1",
37
+ "L_Thumb2",
38
+ "L_Thumb3",
39
+ "R_Index1",
40
+ "R_Index2",
41
+ "R_Index3",
42
+ "R_Middle1",
43
+ "R_Middle2",
44
+ "R_Middle3",
45
+ "R_Pinky1",
46
+ "R_Pinky2",
47
+ "R_Pinky3",
48
+ "R_Ring1",
49
+ "R_Ring2",
50
+ "R_Ring3",
51
+ "R_Thumb1",
52
+ "R_Thumb2",
53
+ "R_Thumb3"
54
+ ]
scripts/gradio/static/assets/dump_wooden/joints.ply ADDED
Binary file (782 Bytes). View file
 
scripts/gradio/static/assets/dump_wooden/keypoints.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f488cc9d0b650a816f5a8eda49eda7bc796f490e9130a6e0dec5be137d7b929
3
+ size 624
scripts/gradio/static/assets/dump_wooden/kintree.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98a20fa3b53193790b63d9ac3a9c917f2f70fbe6e053dca495c75317ff4b756a
3
+ size 208
scripts/gradio/static/assets/dump_wooden/skinIndice.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:846794fb90ea01e069435ad242caefb3d5c2f913fef3255c247f48836ddc1bda
3
+ size 194048
scripts/gradio/static/assets/dump_wooden/skinWeights.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:343eac2902627ee6b45b547eb7d9f1526562eca7ba178d1dc71f9e466f307a77
3
+ size 388096
scripts/gradio/static/assets/dump_wooden/uvs.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23c92c60b609261d927990d31cbf6ace0c14cb70a5ac753a5f3927cb8c5c8191
3
+ size 194048
scripts/gradio/static/assets/dump_wooden/v_template.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9fbe1f34bfe8a07442d11166e169318022a18da8bc62ce0a9930dfdb3171050
3
+ size 291072
scripts/gradio/static/scripts3d/create_ground.js ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import * as THREE from "three";
2
+
3
// extract common adaptive logic
// Returns a ground-plane size large enough to cover the animation's travel:
// 3x the data's max horizontal range, with a minimum of 5 units.
function getAdaptiveGridSize(sample_data, default_size = 5) {
  if (sample_data) {
    const bounds = calculateDataBounds(sample_data);
    // NOTE: factor is 3x, not the 1.5x the old comment claimed.
    const grid_size = Math.max(bounds.maxRange * 3, 5);
    console.log(`Adaptive ground size: ${grid_size.toFixed(2)}, data range: ${bounds.maxRange.toFixed(2)}`);
    return grid_size;
  }
  return default_size;
}
13
+
14
/**
 * Build a chessboard-textured plane mesh lying in its local XY plane.
 * Callers rotate it into the desired orientation (see getChessboard /
 * getChessboardXZ). When sample_data is supplied, the plane is sized to the
 * motion's horizontal extent via getAdaptiveGridSize.
 */
function createBaseChessboard(
  grid_size = 5,
  divisions = 10,
  white = "#ffffff",
  black = "#444444",
  texture_size = 1024,
  sample_data = null,
) {
  // Use adaptive sizing if sample_data provided, otherwise use fixed grid_size
  if (sample_data) {
    grid_size = getAdaptiveGridSize(sample_data, grid_size);
  }

  // Create chessboard texture with enhanced visual style
  // Ensure texture_size is divisible by divisions to avoid sub-pixel rendering
  var adjusted_texture_size = Math.floor(texture_size / divisions) * divisions;
  var canvas = document.createElement("canvas");
  canvas.width = canvas.height = adjusted_texture_size;
  var context = canvas.getContext("2d");

  // Disable anti-aliasing for crisp edges
  context.imageSmoothingEnabled = false;

  var step = adjusted_texture_size / divisions; // Now guaranteed to be an integer
  for (var i = 0; i < divisions; i++) {
    for (var j = 0; j < divisions; j++) {
      context.fillStyle = (i + j) % 2 === 0 ? white : black;
      context.fillRect(i * step, j * step, step, step);
    }
  }

  var texture = new THREE.CanvasTexture(canvas);
  // Use NearestFilter for sharp/crisp edges between chess squares
  texture.wrapS = THREE.RepeatWrapping;
  texture.wrapT = THREE.RepeatWrapping;
  texture.magFilter = THREE.NearestFilter;
  texture.minFilter = THREE.NearestFilter;
  texture.generateMipmaps = false;

  // Create plane geometry
  var planeGeometry = new THREE.PlaneGeometry(grid_size, grid_size);

  // Enhanced material with better visual properties
  var planeMaterial = new THREE.MeshStandardMaterial({
    map: texture,
    side: THREE.DoubleSide,
    transparent: true,
    opacity: 0.85,
    roughness: 0.9,
    metalness: 0.1,
    emissiveIntensity: 0.05,
  });

  // Create grid mesh
  var plane = new THREE.Mesh(planeGeometry, planeMaterial);
  plane.receiveShadow = true;

  return plane;
}
73
+
74
function getChessboard(...args) {
  // Board for Z-up scenes: flip the base plane about X (it stays in the XY plane).
  const board = createBaseChessboard(...args);
  board.rotation.x = -Math.PI;
  return board;
}
79
+
80
function getChessboardXZ(...args) {
  // Board for Y-up scenes: lay the base plane flat onto the XZ plane.
  const board = createBaseChessboard(...args);
  board.rotation.x = -Math.PI / 2;
  return board;
}
85
+
86
function getCoordinate(axisLength) {
  // Build an RGB axis triad of length axisLength: X=red, Y=green, Z=blue.
  const axes = new THREE.Group();

  const specs = [
    { end: new THREE.Vector3(axisLength, 0, 0), color: 0xff0000 }, // red X axis
    { end: new THREE.Vector3(0, axisLength, 0), color: 0x00ff00 }, // green Y axis
    { end: new THREE.Vector3(0, 0, axisLength), color: 0x0000ff }, // blue Z axis
  ];

  specs.forEach(({ end, color }) => {
    const geometry = new THREE.BufferGeometry().setFromPoints([
      new THREE.Vector3(0, 0, 0),
      end,
    ]);
    const material = new THREE.LineBasicMaterial({ color });
    axes.add(new THREE.Line(geometry, material));
  });

  return axes;
}
120
+
121
function calculateDataBounds(sample_data) {
  // Axis-aligned bounding box over every joint position of every frame.
  let minX = Infinity, maxX = -Infinity;
  let minY = Infinity, maxY = -Infinity;
  let minZ = Infinity, maxZ = -Infinity;

  if (sample_data && sample_data.length > 0) {
    for (const frame of sample_data) {
      if (!frame.positions || !Array.isArray(frame.positions)) continue;
      for (const pos of frame.positions) {
        // support multiple position data formats: {x,y,z} objects or [x,y,z] arrays
        let x, y, z;
        if (typeof pos === "object") {
          x = pos.x !== undefined ? pos.x : pos[0];
          y = pos.y !== undefined ? pos.y : pos[1];
          z = pos.z !== undefined ? pos.z : pos[2];
        } else if (Array.isArray(pos)) {
          [x, y, z] = pos;
        }

        if (x === undefined || y === undefined || z === undefined) continue;
        minX = Math.min(minX, x);
        maxX = Math.max(maxX, x);
        minY = Math.min(minY, y);
        maxY = Math.max(maxY, y);
        minZ = Math.min(minZ, z);
        maxZ = Math.max(maxZ, z);
      }
    }
  }

  // No valid data: collapse the box to the origin.
  if (minX === Infinity || maxX === -Infinity) {
    minX = maxX = minY = maxY = minZ = maxZ = 0;
  }

  const rangeX = Math.abs(maxX - minX);
  const rangeY = Math.abs(maxY - minY);
  const rangeZ = Math.abs(maxZ - minZ);

  // Ground sizing only cares about horizontal travel (X/Z plane).
  const maxRange = Math.max(rangeX, rangeZ);

  // Debug output for sizing decisions.
  console.log(
    `Data boundaries: X[${minX.toFixed(2)}, ${maxX.toFixed(2)}], Y[${minY.toFixed(2)}, ${maxY.toFixed(2)}], Z[${minZ.toFixed(2)}, ${maxZ.toFixed(2)}]`,
  );
  console.log(
    `Ranges: X=${rangeX.toFixed(2)}, Y=${rangeY.toFixed(2)}, Z=${rangeZ.toFixed(2)}, Max=${maxRange.toFixed(2)}`,
  );

  return { minX, maxX, minY, maxY, minZ, maxZ, rangeX, rangeY, rangeZ, maxRange };
}
190
+
191
+ export { calculateDataBounds, getChessboard, getChessboardXZ, getCoordinate };
scripts/gradio/static/scripts3d/create_scene.js ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import * as THREE from "three";
2
+ import { getChessboard, getChessboardXZ, getCoordinate } from "./create_ground.js";
3
+
4
function create_plane(scene) {
  // 20x20 grey ground plane lowered to y = -1 that receives shadows.
  const mesh = new THREE.Mesh(
    new THREE.PlaneGeometry(20, 20),
    new THREE.MeshStandardMaterial({ color: 0x808080 }),
  );
  mesh.position.y = -1;
  mesh.receiveShadow = true;
  scene.add(mesh);
}
12
+
13
function create_cube(scene) {
  // Unit cube lifted to y = 1 that casts shadows (simple debug prop).
  const mesh = new THREE.Mesh(
    new THREE.BoxGeometry(),
    new THREE.MeshPhongMaterial({ color: 0xffffff }),
  );
  mesh.position.y = 1;
  mesh.castShadow = true;
  scene.add(mesh);
}
22
+
23
/**
 * Configure camera, lighting, fog, shadows and (optionally) a chessboard
 * ground for the viewer scene. Supports Z-up (forward ±Y) and Y-up
 * (forward ±Z) coordinate conventions. Always returns 0.
 */
function create_scene(scene, camera, renderer, use_ground = true, axis_up = "z", axis_forward = "-y") {
  // NOTE(review): width/height are computed but never used in this function.
  const width = document.querySelector(".container").offsetWidth;
  const height = width;

  // Camera setup based on axis orientation
  if (axis_up == "z") {
    camera.up.set(0, 0, 1);
    if (axis_forward == "-y") {
      camera.position.set(0, -3, 3);
    } else if (axis_forward == "y") {
      camera.position.set(0, 3, 3);
    }
    camera.lookAt(new THREE.Vector3(0, 0, 1.5));
  } else if (axis_up == "y") {
    camera.up.set(0, 1, 0);
    if (axis_forward == "z") {
      camera.position.set(0, 2.5, 5);
    } else if (axis_forward == "-z") {
      camera.position.set(0, 2.5, -5);
    }
    camera.lookAt(new THREE.Vector3(0, 1, 0));
  }

  scene.background = new THREE.Color(0x000000);

  // ===== Fog for depth perception =====
  // Using FogExp2 for natural exponential falloff, density ~0.06
  scene.fog = new THREE.FogExp2(0x424242, 0.06);

  // ===== Shadow Configuration =====
  renderer.shadowMap.enabled = true;
  renderer.shadowMap.type = THREE.PCFSoftShadowMap;

  // ===== Enhanced Lighting Setup =====

  // 1. Hemisphere Light - natural sky/ground ambient
  const hemisphereLight = new THREE.HemisphereLight(
    0xffffff, // sky color
    0x444444, // ground color
    1.8 // intensity
  );
  hemisphereLight.position.set(0, 2, 0);
  scene.add(hemisphereLight);

  // 2. Main Directional Light (key light with shadows)
  const directionalLight = new THREE.DirectionalLight(0xffffff, 1.5);
  // Key light placed up-and-to-the-side relative to the forward axis.
  if (axis_up == "z") {
    if (axis_forward == "-y") {
      directionalLight.position.set(-3, 1, 5);
    } else if (axis_forward == "y") {
      directionalLight.position.set(3, 1, 5);
    }
  } else if (axis_up == "y") {
    if (axis_forward == "z") {
      directionalLight.position.set(3, 5, 4);
    } else if (axis_forward == "-z") {
      directionalLight.position.set(3, 5, -4);
    }
  }
  directionalLight.castShadow = true;
  directionalLight.shadow.mapSize.width = 2048;
  directionalLight.shadow.mapSize.height = 2048;
  directionalLight.shadow.camera.near = 0.5;
  directionalLight.shadow.camera.far = 50;
  directionalLight.shadow.camera.left = -10;
  directionalLight.shadow.camera.right = 10;
  directionalLight.shadow.camera.top = 10;
  directionalLight.shadow.camera.bottom = -10;
  directionalLight.shadow.bias = -0.0001;
  scene.add(directionalLight);

  // 3. Fill Light (softer, from opposite side)
  const fillLight = new THREE.DirectionalLight(0xaaccff, 0.4);
  fillLight.position.set(-3, 3, -2);
  scene.add(fillLight);

  // 4. Rim Light (back light for depth)
  const rimLight = new THREE.DirectionalLight(0xffeedd, 0.3);
  rimLight.position.set(0, 4, -5);
  scene.add(rimLight);

  // ===== Ground Setup =====
  // NOTE(review): `var plane` is declared in both branches; harmless with
  // `var` hoisting but would be cleaner as a single declaration.
  if (use_ground) {
    if (axis_up == "z") {
      var plane = getChessboard(50, 50, '#ffffff', '#3a3a3a', 1024);
      plane.name = 'ground';
      plane.receiveShadow = true;
      scene.add(plane);
    } else if (axis_up == "y") {
      var plane = getChessboardXZ(50, 50, '#ffffff', '#3a3a3a', 1024);
      plane.name = 'ground';
      plane.receiveShadow = true;
      scene.add(plane);
    }

    // Optional: coordinate axes helper
    // var coord = getCoordinate(1);
    // scene.add(coord);
  }

  return 0;
}
125
+
126
/**
 * Frame the camera so every visible mesh (excluding lights, helpers, the
 * ground plane, and any object whose name is in excludeNames) fits in view.
 * Places the camera on a 25° elevation / 45° azimuth ray from the bounding
 * sphere's center and, when OrbitControls are given, retargets and clamps
 * their zoom range. No-op when the scene contains nothing measurable.
 */
function fitCameraToScene(scene, camera, controls = null, opts = {}) {
  const { margin = 1.05, axis_up = "y", excludeNames = ["ground"] } = opts;

  const box = new THREE.Box3();
  const tmp = new THREE.Box3();
  let has = false;

  // Union the bounding boxes of all relevant meshes.
  scene.traverse((obj) => {
    if (!obj || !obj.visible) return;
    if (obj.isLight) return;
    const t = obj.type || "";
    if (t.endsWith("Helper")) return;
    if (excludeNames && excludeNames.includes(obj.name)) return;

    if (obj.isMesh) {
      // Plane geometries (grounds) would dwarf the actual content.
      if (obj.geometry && obj.geometry.type === "PlaneGeometry") return;
      try {
        tmp.setFromObject(obj);
        if (!tmp.isEmpty()) {
          if (!has) {
            box.copy(tmp);
            has = true;
          } else {
            box.union(tmp);
          }
        }
      } catch (_) {}
    }
  });

  if (!has || box.isEmpty()) return;

  const sphere = new THREE.Sphere();
  box.getBoundingSphere(sphere);
  const center = sphere.center.clone();
  const radius = Math.max(sphere.radius, 1e-3);

  // Distance needed so the sphere fits both the vertical and horizontal FOV.
  const vFov = THREE.MathUtils.degToRad(camera.fov);
  const hFov = 2 * Math.atan(Math.tan(vFov / 2) * camera.aspect);
  const distV = radius / Math.sin(vFov / 2);
  const distH = radius / Math.sin(hFov / 2);
  const dist = Math.max(distV, distH) * margin;

  // 25° top-down view (azimuth 45°, elevation 25°)
  const elev = THREE.MathUtils.degToRad(25);
  const azim = Math.PI / 4;
  const horiz = Math.cos(elev);
  let dir;

  if (axis_up === "y") {
    dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.sin(elev), Math.cos(azim) * horiz);
    camera.up.set(0, 1, 0);
  } else {
    dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.cos(azim) * horiz, Math.sin(elev));
    camera.up.set(0, 0, 1);
  }

  camera.position.copy(center).add(dir.multiplyScalar(dist));
  camera.updateProjectionMatrix();
  camera.lookAt(center);

  if (controls) {
    // Keep orbit target on the content and bound the zoom range sensibly.
    controls.target.copy(center);
    controls.minDistance = Math.max(radius * 0.2, 0.1);
    controls.maxDistance = Math.max(dist * 3, controls.minDistance + 0.1);
    controls.update();
  }
}
194
+
195
+ export { create_scene, fitCameraToScene };
scripts/gradio/static/scripts3d/draw_skeleton.js ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import * as THREE from "three";

// Default skeleton connectivity as [child, parent] joint-index pairs for a
// 25-joint body layout (indices follow the keypoint order produced by the
// backend pose estimator — assumed, TODO confirm against the server output).
const defaultEdges = [
    [1, 0],
    [2, 1],
    [3, 2],
    [4, 3],
    [5, 1],
    [6, 5],
    [7, 6],
    [8, 1],
    [9, 8],
    [10, 9],
    [11, 10],
    [12, 8],
    [13, 12],
    [14, 13],
    [15, 0],
    [16, 0],
    [17, 15],
    [18, 16],
    [19, 14],
    [20, 19],
    [21, 14],
    [22, 11],
    [23, 22],
    [24, 11],
];

// Module-level registry of every Object3D added by this module, so the next
// draw call can remove the previous skeleton and dispose its GPU resources.
var geometries = [];
31
+
32
/**
 * Remove every previously drawn skeleton object from the scene, release the
 * associated GPU buffers, and reset the module-level registry.
 *
 * @param {THREE.Scene} scene - Scene the objects were added to.
 */
function clearGeometries(scene) {
    for (const item of geometries) {
        scene.remove(item);
        if (item.geometry) {
            item.geometry.dispose();
        }
        if (item.material) {
            item.material.dispose();
        }
    }
    geometries = [];
}
40
+
41
/**
 * Add one red sphere per visible keypoint to the scene and register each
 * sphere for later cleanup. A keypoint carrying a 4th component (confidence)
 * below 0.1 is treated as invisible and skipped.
 *
 * @param {Array<number[]>} keypoints - Per-joint [x, y, z] or [x, y, z, conf].
 * @param {THREE.Scene} scene - Target scene.
 * @param {number} radius_joint - Sphere radius for each joint marker.
 */
function drawJoints(keypoints, scene, radius_joint) {
    // One shared geometry/material pair keeps the draw cheap; only the mesh
    // transform differs per joint.
    const jointGeometry = new THREE.SphereGeometry(radius_joint, 32, 32);
    const jointMaterial = new THREE.MeshStandardMaterial({ color: 0xff0000 });

    for (const point of keypoints) {
        // Skip low-confidence joints when a confidence score is present.
        const hasConfidence = point.length > 3;
        if (hasConfidence && point[3] < 0.1) {
            continue;
        }

        const joint = new THREE.Mesh(jointGeometry, jointMaterial);
        joint.position.set(point[0], point[1], point[2]);
        geometries.push(joint);
        scene.add(joint);
    }
}
57
+
58
/**
 * Draw one blue ellipsoid per skeleton edge, stretched along the segment
 * between its two endpoint joints, and register each for later cleanup.
 *
 * @param {Array<number[]>} keypoints - Per-joint [x, y, z] or [x, y, z, conf].
 * @param {Array<[number, number]>} edges - Joint-index pairs to connect.
 * @param {THREE.Scene} scene - Target scene.
 * @param {number} radius_limb - Cross-section radius of each limb ellipsoid.
 */
function drawLimbs(keypoints, edges, scene, radius_limb) {
    // Shared unit sphere; each limb instance is scaled along its local z to
    // span the joint-to-joint distance.
    const ellipsoidGeometry = new THREE.SphereGeometry(radius_limb, 32, 32);
    const ellipsoidMaterial = new THREE.MeshStandardMaterial({ color: 0x0000ff });

    edges.forEach((edge) => {
        const idx1 = edge[0];
        const idx2 = edge[1];

        // Validate indices — edges may reference joints absent from this frame.
        if (idx1 >= keypoints.length || idx2 >= keypoints.length) {
            return;
        }

        // Skip the limb when either endpoint is below the confidence cutoff.
        const p1 = keypoints[idx1];
        const p2 = keypoints[idx2];
        if (
            (p1.length > 3 && p1[3] < 0.1) ||
            (p2.length > 3 && p2[3] < 0.1)
        ) {
            return;
        }

        const start = new THREE.Vector3(p1[0], p1[1], p1[2]);
        const end = new THREE.Vector3(p2[0], p2[1], p2[2]);

        const direction = new THREE.Vector3().subVectors(end, start);
        const length = direction.length();

        // create an ellipsoid
        const ellipsoid = new THREE.Mesh(ellipsoidGeometry, ellipsoidMaterial);

        // scale: x,y = 1 (radius_limb), z matches half-length so the sphere's
        // diameter along z equals the segment length after scaling.
        ellipsoid.scale.set(1, 1, length / 2 / radius_limb);

        // position: midpoint of the segment.
        ellipsoid.position.addVectors(start, end).multiplyScalar(0.5);

        // rotation: orient local z along the segment (sphere is symmetric, so
        // the facing sign does not matter).
        ellipsoid.lookAt(end);

        geometries.push(ellipsoid);
        scene.add(ellipsoid);
    });
}
103
+
104
// Draw one skeleton (joints first, then limbs) without clearing the scene;
// callers that manage multiple skeletons clear once and call this per person.
function drawSingleSkeleton(keypoints, edges, scene, radius_joint, radius_limb) {
    drawJoints(keypoints, scene, radius_joint);
    drawLimbs(keypoints, edges, scene, radius_limb);
}

// Public entry: replace whatever skeleton is currently shown with a single
// skeleton using the module's default 25-joint edge list.
function visualizeSkeleton(keypoints, scene, radius_joint = 0.02, radius_limb = 0.03) {
    clearGeometries(scene);
    drawSingleSkeleton(keypoints, defaultEdges, scene, radius_joint, radius_limb);
}

// Public entry: replace the current drawing with one skeleton per entry of
// `infos`; each entry supplies its own `keypoints3d` and `edges`.
function visualizeAllSkeleton(infos, scene, radius_joint = 0.02, radius_limb = 0.03) {
    clearGeometries(scene);
    infos.forEach((info) => {
        drawSingleSkeleton(info.keypoints3d, info.edges, scene, radius_joint, radius_limb);
    });
}

export { visualizeAllSkeleton, visualizeSkeleton };
scripts/gradio/static/scripts3d/load_smpl.js ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import * as THREE from "three";

// Number of bone influences stored per vertex in the skinning buffers
// (matches the itemSize of the skinIndex / skinWeight attributes below).
const NUM_SKIN_WEIGHTS = 4;
4
+
5
/**
 * Build a rigged SMPL-H body mesh for the given shape coefficients.
 *
 * Fetches the binary template assets (vertices, faces, skin weights/indices,
 * joint template), applies linear blend-shape offsets — one offset buffer per
 * beta, fetched as `shapeoffset_<i>.bin` / `shapeoffset_j_<i>.bin` — to both
 * the vertex and joint templates, builds the SMPL-H bone hierarchy and
 * returns the skinned mesh.
 *
 * @param {number[]} shapes - Shape (beta) coefficients.
 * @param {string} gender - "neutral" | "male" | "female"; selects the mesh
 *   tint. Only neutral asset buffers are shipped, so any other value falls
 *   back to the neutral geometry (previously this crashed with a TypeError
 *   because the URL table had no male/female entries).
 * @returns {Promise<{bones: THREE.Bone[], skeleton: THREE.Skeleton, mesh: THREE.SkinnedMesh}>}
 */
async function load_smpl_with_shapes(shapes, gender) {
    const asset_urls = {
        neutral: [
            "/static/assets/dump_smplh/v_template.bin",
            "/static/assets/dump_smplh/faces.bin",
            "/static/assets/dump_smplh/skinWeights.bin",
            "/static/assets/dump_smplh/skinIndice.bin",
            "/static/assets/dump_smplh/j_template.bin",
        ],
    };
    // Fall back to neutral assets/tint for genders without dedicated buffers.
    const urls = asset_urls[gender] || asset_urls.neutral;
    const gender_color = {
        neutral: 0xffffff,
        male: 0x6495ed, // Cornflower blue (lighter blue)
        female: 0xff6b81, // Light coral (softer red)
    };
    const color = gender_color[gender] !== undefined ? gender_color[gender] : gender_color.neutral;

    const geometry = new THREE.BufferGeometry();
    const buffers = await Promise.all(urls.map((url) => fetch(url).then((response) => response.arrayBuffer())));
    const v_template = new Float32Array(buffers[0]);
    const offsets = await Promise.all(
        shapes.map((_, i) =>
            fetch("/static/assets/dump_smplh/shapeoffset_" + i + ".bin")
                .then((response) => response.arrayBuffer())
                .then((buffer) => new Float32Array(buffer)),
        ),
    );
    const offsets_j = await Promise.all(
        shapes.map((_, i) =>
            fetch("/static/assets/dump_smplh/shapeoffset_j_" + i + ".bin")
                .then((response) => response.arrayBuffer())
                .then((buffer) => new Float32Array(buffer)),
        ),
    );

    // v_template += sum_i beta_i * shapeoffset_i (vertex blend shapes).
    offsets.forEach((offset, i) => {
        for (let j = 0; j < v_template.length / 3; j++) {
            v_template[3 * j] += offset[3 * j] * shapes[i];
            v_template[3 * j + 1] += offset[3 * j + 1] * shapes[i];
            v_template[3 * j + 2] += offset[3 * j + 2] * shapes[i];
        }
    });

    // NOTE: Uint16 indices cap addressable vertices at 65535 — fine for the
    // SMPL template (6890 vertices).
    const faces = new Uint16Array(buffers[1]);
    const skinWeights = new Float32Array(buffers[2]);
    const skinIndices = new Uint16Array(buffers[3]);

    // keypoints += sum_i beta_i * shapeoffset_j_i (joint blend shapes).
    const keypoints = new Float32Array(buffers[4]);
    offsets_j.forEach((offset_j, i) => {
        for (let j = 0; j < keypoints.length / 3; j++) {
            keypoints[3 * j] += offset_j[3 * j] * shapes[i];
            keypoints[3 * j + 1] += offset_j[3 * j + 1] * shapes[i];
            keypoints[3 * j + 2] += offset_j[3 * j + 2] * shapes[i];
        }
    });

    // Kintree: edges[i] is the parent joint index of joint i (SMPL-H, 52
    // joints, index 0 = pelvis root with parent -1).
    const edges = [
        -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20,
        31, 32, 20, 34, 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50,
    ];

    // Build the bone hierarchy; each bone's position is relative to its parent.
    var rootBone = new THREE.Bone();
    rootBone.position.set(keypoints[0], keypoints[1], keypoints[2]);
    var bones = [rootBone];
    for (let i = 1; i < keypoints.length / 3; i++) {
        const bone = new THREE.Bone();
        const parentIndex = edges[i];
        bone.position.set(
            keypoints[3 * i] - keypoints[3 * parentIndex],
            keypoints[3 * i + 1] - keypoints[3 * parentIndex + 1],
            keypoints[3 * i + 2] - keypoints[3 * parentIndex + 2],
        );
        bones.push(bone);
        bones[parentIndex].add(bone);
    }
    var skeleton = new THREE.Skeleton(bones);
    geometry.setIndex(new THREE.BufferAttribute(faces, 1));

    geometry.setAttribute("position", new THREE.BufferAttribute(v_template, 3));
    geometry.setAttribute("skinIndex", new THREE.BufferAttribute(skinIndices, NUM_SKIN_WEIGHTS));
    geometry.setAttribute("skinWeight", new THREE.BufferAttribute(skinWeights, NUM_SKIN_WEIGHTS));

    geometry.computeVertexNormals();

    // Skinning is automatic for SkinnedMesh materials since three r125; the
    // legacy `skinning: true` flag only produced a console warning on r160.
    const material = new THREE.MeshStandardMaterial({
        color: color,
        side: THREE.DoubleSide,
    });
    var mesh = new THREE.SkinnedMesh(geometry, material);
    mesh.castShadow = true;
    mesh.receiveShadow = true;
    mesh.add(bones[0]);
    mesh.bind(skeleton);
    return { bones, skeleton, mesh };
}
112
+
113
+ function reshapeArrayTo2D(float32Array, rows) {
114
+ const twoDArray = [];
115
+ const cols = float32Array.length / rows;
116
+ for (let i = 0; i < rows; i++) {
117
+ const row = new Float32Array(cols);
118
+ for (let j = 0; j < cols; j++) {
119
+ row[j] = float32Array[i * cols + j];
120
+ }
121
+ twoDArray.push(row);
122
+ }
123
+ return twoDArray;
124
+ }
125
+
126
+ export { load_smpl_with_shapes };
scripts/gradio/static/scripts3d/load_wooden.js ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import * as THREE from 'three';

// Number of bone influences per vertex in the skinning buffers (itemSize of
// the skinIndex / skinWeight attributes).
const NUM_SKIN_WEIGHTS = 4;

// SMPL-H joint names (52 joints), used as Bone names when no
// joint_names.json override is shipped with the asset pack.
const SMPLH_JOINT_NAMES = [
    "Pelvis", "L_Hip", "R_Hip", "Spine1",
    "L_Knee", "R_Knee", "Spine2",
    "L_Ankle", "R_Ankle", "Spine3",
    "L_Foot", "R_Foot", "Neck", "L_Collar", "R_Collar", "Head",
    "L_Shoulder", "R_Shoulder", "L_Elbow", "R_Elbow",
    "L_Wrist", "R_Wrist",
    "L_Index1", "L_Index2", "L_Index3",
    "L_Middle1", "L_Middle2", "L_Middle3",
    "L_Pinky1", "L_Pinky2", "L_Pinky3",
    "L_Ring1", "L_Ring2", "L_Ring3",
    "L_Thumb1", "L_Thumb2", "L_Thumb3",
    "R_Index1", "R_Index2", "R_Index3",
    "R_Middle1", "R_Middle2", "R_Middle3",
    "R_Pinky1", "R_Pinky2", "R_Pinky3",
    "R_Ring1", "R_Ring2", "R_Ring3",
    "R_Thumb1", "R_Thumb2", "R_Thumb3",
];

// Default kintree (parent indices) for SMPL-H 52 joints; entry i is the
// parent joint index of joint i (-1 marks the pelvis root).
const DEFAULT_EDGES = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50];
27
+
28
+ /**
29
+ * Load wooden model from binary files
30
+ * @param {Array} shapes - Shape parameters (unused for wooden model)
31
+ * @param {string} gender - Gender parameter (unused for wooden model)
32
+ * @returns {Object} { bones, skeleton, mesh, jointNames }
33
+ */
34
+ async function load_wooden(shapes, gender, basePath = '/static/assets/dump_wooden') {
35
+ console.log("Loading wooden model...");
36
+ console.log(`Using base path: ${basePath}`);
37
+
38
+ const urls = [
39
+ `${basePath}/v_template.bin`,
40
+ `${basePath}/faces.bin`,
41
+ `${basePath}/skinWeights.bin`,
42
+ `${basePath}/skinIndice.bin`,
43
+ `${basePath}/j_template.bin`,
44
+ `${basePath}/uvs.bin`,
45
+ ];
46
+
47
+ // Try to load kintree
48
+ let edges = [...DEFAULT_EDGES];
49
+ try {
50
+ const kintreeResponse = await fetch(`${basePath}/kintree.bin`);
51
+ if (kintreeResponse.ok) {
52
+ const kintreeBuffer = await kintreeResponse.arrayBuffer();
53
+ edges = Array.from(new Int32Array(kintreeBuffer));
54
+ console.log(`Loaded kintree with ${edges.length} joints`);
55
+ }
56
+ } catch (e) {
57
+ console.log('Using default kintree');
58
+ }
59
+
60
+ // Try to load joint names
61
+ let jointNames = [...SMPLH_JOINT_NAMES];
62
+ try {
63
+ const namesResponse = await fetch(`${basePath}/joint_names.json`);
64
+ if (namesResponse.ok) {
65
+ jointNames = await namesResponse.json();
66
+ console.log(`Loaded ${jointNames.length} joint names`);
67
+ }
68
+ } catch (e) {
69
+ console.log('Using default joint names');
70
+ }
71
+
72
+ // Load main buffers
73
+ const buffers = await Promise.all(urls.map(url => fetch(url).then(response => response.arrayBuffer())));
74
+ const v_template = new Float32Array(buffers[0]);
75
+ const faces = new Uint16Array(buffers[1]);
76
+ const skinWeights = new Float32Array(buffers[2]);
77
+ const skinIndices = new Uint16Array(buffers[3]);
78
+ const keypoints = new Float32Array(buffers[4]);
79
+ const uvs = new Float32Array(buffers[5]);
80
+
81
+ console.log(`Vertices: ${v_template.length / 3}, Faces: ${faces.length / 3}, Joints: ${keypoints.length / 3}`);
82
+
83
+ // Create geometry
84
+ const geometry = new THREE.BufferGeometry();
85
+ geometry.setAttribute('position', new THREE.BufferAttribute(v_template, 3));
86
+ geometry.setIndex(new THREE.BufferAttribute(faces, 1));
87
+ geometry.setAttribute('skinIndex', new THREE.BufferAttribute(skinIndices, NUM_SKIN_WEIGHTS));
88
+ geometry.setAttribute('skinWeight', new THREE.BufferAttribute(skinWeights, NUM_SKIN_WEIGHTS));
89
+ geometry.setAttribute('uv', new THREE.BufferAttribute(uvs, 2));
90
+
91
+ // Create bones
92
+ const numJoints = keypoints.length / 3;
93
+
94
+ // Ensure edges array matches joint count
95
+ while (edges.length < numJoints) {
96
+ edges.push(0);
97
+ }
98
+
99
+ // Root bone
100
+ var rootBone = new THREE.Bone();
101
+ rootBone.position.set(keypoints[0], keypoints[1], keypoints[2]);
102
+ rootBone.name = jointNames[0] || 'Pelvis';
103
+ var bones = [rootBone];
104
+
105
+ // Create child bones
106
+ for (let i = 1; i < numJoints; i++) {
107
+ const bone = new THREE.Bone();
108
+ const parentIndex = edges[i];
109
+
110
+ if (parentIndex >= 0 && parentIndex < i) {
111
+ bone.position.set(
112
+ keypoints[3 * i] - keypoints[3 * parentIndex],
113
+ keypoints[3 * i + 1] - keypoints[3 * parentIndex + 1],
114
+ keypoints[3 * i + 2] - keypoints[3 * parentIndex + 2]
115
+ );
116
+ bone.name = jointNames[i] || `Joint_${i}`;
117
+ bones.push(bone);
118
+ bones[parentIndex].add(bone);
119
+ console.log(`Joint ${i} (${bone.name}): parent=${parentIndex}, pos=${bone.position.toArray()}`);
120
+ } else {
121
+ console.warn(`Invalid parent index ${parentIndex} for joint ${i}, attaching to root`);
122
+ bone.position.set(0, 0, 0);
123
+ bone.name = jointNames[i] || `Joint_${i}`;
124
+ bones.push(bone);
125
+ bones[0].add(bone);
126
+ }
127
+ }
128
+
129
+ var skeleton = new THREE.Skeleton(bones);
130
+
131
+ geometry.computeVertexNormals();
132
+
133
+ // --- Texture Loading ---
134
+ const textureLoader = new THREE.TextureLoader();
135
+
136
+ async function loadTextureAsync(url, isSRGB = true) {
137
+ const tex = await textureLoader.loadAsync(url);
138
+ tex.flipY = false;
139
+ if (isSRGB) tex.colorSpace = THREE.SRGBColorSpace;
140
+ return tex;
141
+ }
142
+
143
+ const [baseColorMap] = await Promise.all([
144
+ loadTextureAsync(`${basePath}/Boy_lambert4_BaseColor.webp`, true),
145
+ ]);
146
+
147
+ // Create material - PBR with textures (optimized for dark mode)
148
+ const material = new THREE.MeshStandardMaterial({
149
+ map: baseColorMap,
150
+ roughness: 0.6, // Lower roughness for better light reflection
151
+ metalness: 0.2, // Lower metalness for more natural look
152
+ envMapIntensity: 1.5, // Enhanced environment lighting
153
+ });
154
+
155
+ var mesh = new THREE.SkinnedMesh(geometry, material);
156
+ mesh.castShadow = true;
157
+ mesh.receiveShadow = true;
158
+ mesh.add(bones[0]);
159
+ mesh.bind(skeleton);
160
+
161
+ console.log(`Wooden model loaded: ${numJoints} joints, ${v_template.length / 3} vertices`);
162
+
163
+ return { bones, skeleton, mesh, jointNames, edges };
164
+ }
165
+
166
+ export { DEFAULT_EDGES, load_wooden, NUM_SKIN_WEIGHTS, SMPLH_JOINT_NAMES };
167
+
scripts/gradio/templates/element/blank.html ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<!-- Jinja2 base template: child pages override the title / style /
     content_block / script_block / help blocks. Bootstrap + jQuery come
     from public CDNs, so the viewer needs network access. -->
<html>

<head>
    <title>{% block title %} {% endblock %}</title>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <!-- Add Bootstrap CSS (CDN) -->
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
    <!-- Add jQuery and Bootstrap JS (CDN) -->
    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@popperjs/core@2.10.2/dist/umd/popper.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.min.js"></script>
    <style>
        /* Dark mode base */
        html, body {
            background: #1a1a2e !important;
            color: #e2e8f0;
            margin: 0;
            padding: 0;
        }

        .container {
            padding: 0;
            border: none;
            background: #1a1a2e;
        }

        /* Help box is hidden by default; pages may restyle it. */
        .alert-success {
            display: none;
        }

        {% block style %}
        {% endblock %}
    </style>
</head>

<body>

    {% block content_block %}
    {% endblock %}

    {% block script_block %}
    {% endblock %}

    <!-- Optional help text supplied by child templates. -->
    <div class="container alert-success mt-3" role="alert">
        {% block help %}

        {% endblock %}
    </div>
</body>

</html>
scripts/gradio/templates/error_file_not_found.html ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<!-- Standalone 404 page returned when a requested motion file is missing.
     lang corrected from "zh-CN" to "en": all visible text is English. -->
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>File not found - 404</title>
    <style>
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }

        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Arial, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            display: flex;
            align-items: center;
            justify-content: center;
            padding: 20px;
        }

        .container {
            background: white;
            border-radius: 12px;
            box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
            max-width: 500px;
            width: 100%;
            overflow: hidden;
            text-align: center;
            padding: 40px;
        }

        h1 {
            color: #e74c3c;
            font-size: 4em;
            margin-bottom: 20px;
        }

        h2 {
            color: #333;
            font-size: 1.5em;
            margin-bottom: 15px;
        }

        p {
            color: #666;
            font-size: 1.1em;
            line-height: 1.6;
        }
    </style>
</head>

<body>
    <div class="container">
        <h1>404</h1>
        <h2>Oops! File Not Found</h2>
        <p>We couldn't find the file you're looking for.<br>Please check the URL or try again later.</p>
    </div>
</body>

</html>
scripts/gradio/templates/index_smpl_gradio.html ADDED
@@ -0,0 +1,938 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
{% extends 'element/blank.html' %}

{% block content_block %}

<div class="container mt-3">
    <!-- Caption container: per-motion text descriptions, filled in by JS;
         suppressed entirely when the server sets hide_captions. -->
    {% if not hide_captions %}
    <div class="caption-container">
        <div class="motion-info" id="motion-info">
            <div class="loading">
                <i class="fas fa-spinner fa-spin"></i> Loading action descriptions...
            </div>
        </div>
    </div>
    {% endif %}

    <!-- 3D container: the three.js renderer canvas is injected into #vis3d. -->
    <div class="vis3d-wrapper" style="position: relative;">
        <div class="d-flex justify-content-center" id="vis3d">
        </div>
    </div>

    <!-- Playback control panel: wired up by initPlaybackControls() in the
         module script; all controls start disabled until models load. -->
    <div class="control-panel-embedded mt-3">
        <div class="control-row-compact">
            <div class="control-group">
                <button id="playPauseBtn" class="control-btn" title="Play/Pause">
                    <i class="fas fa-play"></i>
                </button>
                <button id="resetBtn" class="control-btn" title="Reset to start">
                    <i class="fas fa-step-backward"></i>
                </button>
            </div>

            <div class="progress-group">
                <input type="range" id="progressSlider" class="progress-slider" min="0" max="100" value="0">
            </div>

            <div class="info-group">
                <div class="frame-info">
                    <span id="currentFrame">0</span> / <span id="totalFrames">0</span>
                </div>
                <div class="loading-status" id="loadingStatus">
                    <i class="fas fa-spinner fa-spin"></i> Loading...
                </div>
            </div>

            <div class="speed-group">
                <label for="speedSlider">Speed:</label>
                <input type="range" id="speedSlider" class="speed-slider" min="0.1" max="3" step="0.1" value="1">
                <span id="speedValue">1.0x</span>
            </div>
        </div>
    </div>
</div>

<!-- Add Font Awesome for icons -->
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">

<!-- Import map pins three.js r160 from the CDN for the module script below. -->
<script type="importmap">
    {
        "imports": {
            "three": "https://cdn.jsdelivr.net/npm/three@0.160.0/build/three.module.js",
            "three/addons/": "https://cdn.jsdelivr.net/npm/three@0.160.0/examples/jsm/"
        }
    }
</script>
68
+
69
<script type="module">
    import * as THREE from 'three';
    import { OrbitControls } from 'three/addons/controls/OrbitControls.js';
    import { getChessboardXZ, getCoordinate } from '/static/scripts3d/create_ground.js';
    import { create_scene, fitCameraToScene } from '/static/scripts3d/create_scene.js';
    import { load_smpl_with_shapes } from '/static/scripts3d/load_smpl.js';

    // --- Renderer / scene state ---
    let scene, camera, renderer;
    let controls;                 // OrbitControls instance
    let infos;                    // per-frame SMPL params: infos[frame] = [person, ...]
    let currentFrame = 0;
    let total_frame = 0;
    const baseIntervalTime = 30;  // ms between frames at 1.0x speed (~33 fps)
    var model_mesh = {};          // person id -> bones array of the loaded mesh

    // --- Playback state ---
    let isPlaying = false;
    let lastFrameTime = 0;        // timestamp of the last frame advance (ms)
    let playbackSpeed = 1.0;      // multiplier applied to baseIntervalTime
    let animationId = null;       // requestAnimationFrame handle
    let modelsLoaded = false;
    let expectedModelCount = 0;
    let loadedModelCount = 0;

    // --- Layout state ---
    let ignoreGlobalTrans = false; // when true, drop per-frame root translation
    let currentOffsets = [];       // per-person x offsets for side-by-side layout
95
+ const updateFrame = () => {
96
+ if (!infos || currentFrame >= total_frame || !modelsLoaded) return;
97
+
98
+ const info = infos[currentFrame];
99
+ let allModelsReady = true;
100
+
101
+ info.forEach(smpl_params => {
102
+ if (!(smpl_params.id in model_mesh)) {
103
+ allModelsReady = false;
104
+ }
105
+ });
106
+
107
+ if (!allModelsReady) {
108
+ return;
109
+ }
110
+
111
+ const offsets = computeOffsets(info.length);
112
+ currentOffsets = offsets;
113
+
114
+ info.forEach((smpl_params, b) => {
115
+ const bones = model_mesh[smpl_params.id];
116
+ const mesh = bones[0].parent;
117
+
118
+ if (ignoreGlobalTrans) {
119
+ mesh.position.set(-offsets[b], 0, 0);
120
+ } else {
121
+ mesh.position.set(
122
+ smpl_params.Th[0][0] - offsets[b],
123
+ smpl_params.Th[0][1],
124
+ smpl_params.Th[0][2]
125
+ );
126
+ }
127
+
128
+ var axis = new THREE.Vector3(smpl_params.Rh[0][0], smpl_params.Rh[0][1], smpl_params.Rh[0][2]);
129
+ var angle = axis.length();
130
+ axis.normalize();
131
+
132
+ var poses_offset = 0;
133
+ if (smpl_params.poses[0].length == 69) {
134
+ poses_offset = -3;
135
+ }
136
+ for (let i = 0; i < bones.length; i++) {
137
+ var axis = new THREE.Vector3(
138
+ smpl_params.poses[0][poses_offset + 3 * i],
139
+ smpl_params.poses[0][poses_offset + 3 * i + 1],
140
+ smpl_params.poses[0][poses_offset + 3 * i + 2]);
141
+ var angle = axis.length();
142
+ axis.normalize();
143
+ var quaternion = new THREE.Quaternion().setFromAxisAngle(axis, angle);
144
+ bones[i].quaternion.copy(quaternion);
145
+ }
146
+ });
147
+
148
+ updateUI();
149
+ }
150
+
151
+ const playLoop = (currentTime) => {
152
+ if (isPlaying && currentTime - lastFrameTime >= (baseIntervalTime / playbackSpeed)) {
153
+ currentFrame += 1;
154
+ if (currentFrame >= total_frame) {
155
+ currentFrame = 0;
156
+ }
157
+ updateFrame();
158
+ lastFrameTime = currentTime;
159
+ }
160
+
161
+ if (isPlaying) {
162
+ animationId = requestAnimationFrame(playLoop);
163
+ }
164
+ }
165
+
166
+ const updateUI = () => {
167
+ document.getElementById('currentFrame').textContent = currentFrame;
168
+ document.getElementById('totalFrames').textContent = total_frame;
169
+
170
+ if (total_frame > 0) {
171
+ const progress = (currentFrame / total_frame) * 100;
172
+ document.getElementById('progressSlider').value = progress;
173
+ }
174
+ }
175
+
176
+ const updateLoadingStatus = () => {
177
+ const loadingElement = document.getElementById('loadingStatus');
178
+ if (!loadingElement) return;
179
+
180
+ if (modelsLoaded) {
181
+ loadingElement.innerHTML = '<i class="fas fa-check"></i> Model loaded';
182
+ loadingElement.className = 'loading-status complete';
183
+ setTimeout(() => {
184
+ loadingElement.className = 'loading-status hidden';
185
+ }, 3000);
186
+ } else {
187
+ loadingElement.innerHTML = `<i class="fas fa-spinner fa-spin"></i> Loading models... (${loadedModelCount}/${expectedModelCount})`;
188
+ loadingElement.className = 'loading-status';
189
+ }
190
+ }
191
+
192
+ const updatePlayPauseButton = () => {
193
+ const playPauseBtn = document.getElementById('playPauseBtn');
194
+ if (playPauseBtn) {
195
+ if (isPlaying) {
196
+ playPauseBtn.innerHTML = '<i class="fas fa-pause"></i>';
197
+ playPauseBtn.title = 'Pause';
198
+ } else {
199
+ playPauseBtn.innerHTML = '<i class="fas fa-play"></i>';
200
+ playPauseBtn.title = 'Play';
201
+ }
202
+ }
203
+ }
204
+
205
+ const enablePlaybackControls = () => {
206
+ const playPauseBtn = document.getElementById('playPauseBtn');
207
+ const resetBtn = document.getElementById('resetBtn');
208
+ const progressSlider = document.getElementById('progressSlider');
209
+ const speedSlider = document.getElementById('speedSlider');
210
+
211
+ [playPauseBtn, resetBtn, progressSlider, speedSlider].forEach(element => {
212
+ if (element) {
213
+ element.disabled = false;
214
+ element.style.opacity = '1';
215
+ element.style.cursor = 'pointer';
216
+ }
217
+ });
218
+
219
+ updatePlayPauseButton();
220
+ }
221
+
222
    // Start playback: only when stopped, data exists and models are loaded.
    // Resets the frame timer so the first advance waits a full interval.
    const playAnimation = () => {
        if (!isPlaying && total_frame > 0 && modelsLoaded) {
            isPlaying = true;
            lastFrameTime = performance.now();
            animationId = requestAnimationFrame(playLoop);
            updatePlayPauseButton();
        }
    }

    // Stop playback and cancel the pending animation-frame callback.
    const pauseAnimation = () => {
        isPlaying = false;
        if (animationId) {
            cancelAnimationFrame(animationId);
            animationId = null;
        }
        updatePlayPauseButton();
    }

    // Stop playback and rewind to frame 0, re-posing the models there.
    const resetAnimation = () => {
        pauseAnimation();
        currentFrame = 0;
        updateFrame();
        updatePlayPauseButton();
    }
246
+
247
    // Wire up all playback controls (buttons, progress/speed sliders and
    // keyboard shortcuts). Controls start disabled; enablePlaybackControls()
    // re-enables them after the models finish loading.
    const initPlaybackControls = () => {
        const playPauseBtn = document.getElementById('playPauseBtn');
        const resetBtn = document.getElementById('resetBtn');
        const progressSlider = document.getElementById('progressSlider');
        const speedSlider = document.getElementById('speedSlider');
        const speedValue = document.getElementById('speedValue');

        // Disable everything until models are loaded.
        [playPauseBtn, resetBtn, progressSlider, speedSlider].forEach(element => {
            if (element) {
                element.disabled = true;
                element.style.opacity = '0.5';
                element.style.cursor = 'not-allowed';
            }
        });

        updatePlayPauseButton();

        playPauseBtn.addEventListener('click', () => {
            if (!modelsLoaded) return;
            if (isPlaying) {
                pauseAnimation();
            } else {
                playAnimation();
            }
        });

        resetBtn.addEventListener('click', () => {
            if (!modelsLoaded) return;
            resetAnimation();
        });

        // Scrubbing: pause while dragging, then resume if playback was active.
        let wasPlaying = false;
        progressSlider.addEventListener('mousedown', () => {
            if (!modelsLoaded) return;
            wasPlaying = isPlaying;
            if (isPlaying) pauseAnimation();
        });

        progressSlider.addEventListener('input', (e) => {
            if (!modelsLoaded) return;
            // Map slider percentage to a clamped frame index.
            const progress = parseFloat(e.target.value);
            currentFrame = Math.floor((progress / 100) * total_frame);
            if (currentFrame >= total_frame) currentFrame = total_frame - 1;
            if (currentFrame < 0) currentFrame = 0;
            updateFrame();
        });

        progressSlider.addEventListener('mouseup', () => {
            if (!modelsLoaded) return;
            if (wasPlaying) playAnimation();
        });

        // Speed changes take effect immediately (playLoop reads playbackSpeed
        // every tick).
        speedSlider.addEventListener('input', (e) => {
            playbackSpeed = parseFloat(e.target.value);
            speedValue.textContent = playbackSpeed.toFixed(1) + 'x';
        });

        // Keyboard shortcuts: Space toggles play, arrows step one frame,
        // Home rewinds.
        document.addEventListener('keydown', (e) => {
            if (!modelsLoaded) return;
            switch (e.code) {
                case 'Space':
                    e.preventDefault();
                    if (isPlaying) {
                        pauseAnimation();
                    } else {
                        playAnimation();
                    }
                    break;
                case 'ArrowLeft':
                    e.preventDefault();
                    if (currentFrame > 0) {
                        currentFrame--;
                        updateFrame();
                    }
                    break;
                case 'ArrowRight':
                    e.preventDefault();
                    if (currentFrame < total_frame - 1) {
                        currentFrame++;
                        updateFrame();
                    }
                    break;
                case 'Home':
                    e.preventDefault();
                    resetAnimation();
                    break;
            }
        });
    }
336
+
337
    // Block on the server's /wait_for_data long-poll until motion data is
    // ready, then fetch it. On timeout or error, fall through to a direct
    // fetch anyway — fetchSMPLData retries on its own when data is empty.
    async function waitAndFetchData() {
        try {
            const waitResponse = await fetch('/wait_for_data');
            const waitResult = await waitResponse.json();

            if (waitResult.status === 'ready') {
                console.log(`Data ready with ${waitResult.frames} frames`);
                fetchSMPLData();
            } else {
                console.log('Timeout waiting for data, trying direct fetch...');
                fetchSMPLData();
            }
        } catch (error) {
            console.error('Error waiting for data:', error);
            fetchSMPLData();
        }
    }

    // Kick off data loading as soon as the module script runs.
    waitAndFetchData();
356
+
357
    // Fetch the per-frame SMPL parameter payload for this page's motion
    // (folder/file come from the Jinja template), rebuild the ground plane to
    // cover the motion's extent, then asynchronously load one skinned model
    // per person in frame 0. Playback is enabled and auto-started only after
    // every expected model has loaded. Retries every 2s while data is empty.
    function fetchSMPLData() {
        fetch('/query_smpl/{{ folder_name }}/{{ file_name }}')
            .then(response => response.json())
            .then(async datas => {
                if (!datas || datas.length === 0) {
                    console.log('No data received, retrying in 2 seconds...');
                    setTimeout(fetchSMPLData, 2000);
                    return;
                }

                console.log(`Received ${datas.length} frames of SMPL data`);
                infos = datas;
                total_frame = datas.length;

                updateGroundWithData(datas);
                document.getElementById('progressSlider').max = 100;
                updateUI();
                updatePlayPauseButton();

                // One model per person present in the first frame.
                expectedModelCount = infos[0].length;

                loadedModelCount = 0;
                modelsLoaded = false;
                updateLoadingStatus();

                infos[0].forEach(data => {
                    load_smpl_with_shapes(data.shapes[0], data.gender).then(result => {
                        scene.add(result.mesh);

                        const skeletonHelper = new THREE.SkeletonHelper(result.mesh);
                        scene.add(skeletonHelper);

                        // updateFrame looks meshes up by person id.
                        model_mesh[data.id] = result.bones;

                        loadedModelCount++;

                        if (loadedModelCount === expectedModelCount) {
                            // All models in: enable controls, frame the camera
                            // (skipping the ground plane) and auto-play.
                            modelsLoaded = true;
                            updateLoadingStatus();
                            updateFrame();
                            enablePlaybackControls();
                            fitCameraToScene(scene, camera, controls, { axis_up: 'y', excludeNames: ['ground'] });
                            setTimeout(() => playAnimation(), 500);
                        } else {
                            updateLoadingStatus();
                        }
                    }).catch(err => {
                        console.error("Failed to load SMPL model:", err);
                    });
                });

                initPlaybackControls();
                animate();
            });
    }
412
+
413
// Replace the existing ground plane with one sized to the motion's root
// trajectory. Per-person root translations (person.Th) are shifted by the
// side-by-side layout offsets before being handed to the chessboard builder.
function updateGroundWithData(smplData) {
    const trajectory = smplData.map((frame) => {
        const layout = computeOffsets(frame.length);
        const positions = frame.map((person, idx) => ({
            x: person.Th[0][0] - layout[idx],
            y: person.Th[0][1],
            z: person.Th[0][2]
        }));
        return { positions };
    });

    // Collect stale ground meshes first; mutating the scene graph while
    // traverse() is walking it is unsafe.
    const stale = [];
    scene.traverse((node) => {
        const isGround = node.isMesh && node.geometry &&
            (node.geometry.type === 'PlaneGeometry' || node.name === 'ground');
        if (isGround) {
            stale.push(node);
        }
    });

    for (const node of stale) {
        scene.remove(node);
        if (node.geometry) node.geometry.dispose();
        if (node.material) {
            if (node.material.map) node.material.map.dispose();
            node.material.dispose();
        }
    }

    const ground = getChessboardXZ(5, 10, '#ffffff', '#444444', 1024, trajectory);
    ground.name = 'ground';
    scene.add(ground);
}
446
+
447
+ init();
448
+
449
// One-time setup of the Three.js scene graph, renderer, and orbit controls.
// Assigns the viewer-wide globals: scene, camera, renderer, controls.
function init() {
    const viewWidth = document.querySelector('.container').offsetWidth;
    const viewHeight = viewWidth * 13 / 16;  // fixed 16:13 aspect ratio

    scene = new THREE.Scene();
    camera = new THREE.PerspectiveCamera(60, viewWidth / viewHeight, 0.01, 100);
    renderer = new THREE.WebGLRenderer({ antialias: true, logarithmicDepthBuffer: true });

    create_scene(scene, camera, renderer, true, 'y', 'z');
    renderer.setPixelRatio(window.devicePixelRatio);
    renderer.setSize(viewWidth, viewHeight);
    document.getElementById('vis3d').appendChild(renderer.domElement);

    controls = new OrbitControls(camera, renderer.domElement);
    controls.minDistance = 1;
    controls.maxDistance = 10;
    fitCameraToScene(scene, camera, controls, { axis_up: 'y', excludeNames: ['ground'] });
}
467
+
468
// Render loop: re-render the scene every animation frame. Frame stepping is
// driven by the playback controls elsewhere, not here.
function animate() {
    requestAnimationFrame(animate);
    renderer.render(scene, camera);
}
472
+
473
// Lateral (x-axis) offsets that lay `batchSize` people out side by side,
// 2 m apart and centred on the origin. Returns one offset per person.
function computeOffsets(batchSize) {
    const spacing = 2.0;
    const startX = -((batchSize - 1) * spacing) / 2;
    return Array.from({ length: batchSize }, (_, i) => startX + i * spacing);
}
483
+
484
+ </script>
485
+
486
+ <style>
487
+ /* Navigation Controls */
488
+ .nav-controls {
489
+ position: absolute;
490
+ top: 10px;
491
+ left: 10px;
492
+ display: flex;
493
+ gap: 10px;
494
+ z-index: 1000;
495
+ }
496
+
497
+ .nav-btn {
498
+ background: rgba(102, 126, 234, 0.9);
499
+ color: white;
500
+ border: none;
501
+ border-radius: 6px;
502
+ padding: 8px 12px;
503
+ font-size: 12px;
504
+ cursor: pointer;
505
+ transition: all 0.3s ease;
506
+ display: flex;
507
+ align-items: center;
508
+ gap: 4px;
509
+ }
510
+
511
+ .nav-btn:hover {
512
+ background: rgba(102, 126, 234, 1);
513
+ transform: translateY(-1px);
514
+ }
515
+
516
+ .container {
517
+ position: relative;
518
+ }
519
+
520
+ /* Control Panel Styles */
521
+ .control-panel-embedded {
522
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
523
+ border-radius: 12px;
524
+ padding: 15px;
525
+ color: white;
526
+ margin-top: 15px;
527
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.2);
528
+ }
529
+
530
+ .control-row-compact {
531
+ display: flex;
532
+ align-items: center;
533
+ justify-content: space-between;
534
+ gap: 15px;
535
+ margin-bottom: 0;
536
+ }
537
+
538
+ .control-group {
539
+ display: flex;
540
+ gap: 10px;
541
+ flex-shrink: 0;
542
+ }
543
+
544
+ .progress-group {
545
+ flex: 1;
546
+ margin: 0 15px;
547
+ }
548
+
549
+ .info-group {
550
+ display: flex;
551
+ align-items: center;
552
+ gap: 10px;
553
+ flex-shrink: 0;
554
+ }
555
+
556
+ .speed-group {
557
+ display: flex;
558
+ align-items: center;
559
+ gap: 8px;
560
+ flex-shrink: 0;
561
+ }
562
+
563
+ .control-btn {
564
+ background: rgba(255, 255, 255, 0.2);
565
+ border: 2px solid rgba(255, 255, 255, 0.3);
566
+ border-radius: 50px;
567
+ width: 50px;
568
+ height: 50px;
569
+ color: white;
570
+ font-size: 18px;
571
+ cursor: pointer;
572
+ transition: all 0.3s ease;
573
+ display: flex;
574
+ align-items: center;
575
+ justify-content: center;
576
+ }
577
+
578
+ .control-btn:hover {
579
+ background: rgba(255, 255, 255, 0.3);
580
+ border-color: rgba(255, 255, 255, 0.6);
581
+ transform: scale(1.1);
582
+ }
583
+
584
+ .control-btn:active {
585
+ transform: scale(0.95);
586
+ }
587
+
588
+ .control-btn:disabled,
589
+ .progress-slider:disabled,
590
+ .speed-slider:disabled {
591
+ pointer-events: none;
592
+ opacity: 0.5 !important;
593
+ cursor: not-allowed !important;
594
+ }
595
+
596
+ .frame-info {
597
+ background: rgba(255, 255, 255, 0.2);
598
+ padding: 6px 12px;
599
+ border-radius: 20px;
600
+ font-family: 'Courier New', monospace;
601
+ font-weight: bold;
602
+ font-size: 14px;
603
+ }
604
+
605
+ .loading-status {
606
+ background: rgba(255, 255, 255, 0.2);
607
+ padding: 6px 12px;
608
+ border-radius: 20px;
609
+ font-size: 12px;
610
+ color: #ffeb3b;
611
+ display: flex;
612
+ align-items: center;
613
+ gap: 6px;
614
+ }
615
+
616
+ .loading-status.hidden {
617
+ display: none;
618
+ }
619
+
620
+ .loading-status.complete {
621
+ color: #4caf50;
622
+ }
623
+
624
+ .progress-slider {
625
+ width: 100%;
626
+ height: 6px;
627
+ border-radius: 3px;
628
+ background: rgba(255, 255, 255, 0.3);
629
+ outline: none;
630
+ cursor: pointer;
631
+ }
632
+
633
+ .progress-slider::-webkit-slider-thumb {
634
+ appearance: none;
635
+ width: 20px;
636
+ height: 20px;
637
+ border-radius: 50%;
638
+ background: white;
639
+ cursor: pointer;
640
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
641
+ }
642
+
643
+ .progress-slider::-moz-range-thumb {
644
+ width: 20px;
645
+ height: 20px;
646
+ border-radius: 50%;
647
+ background: white;
648
+ cursor: pointer;
649
+ border: none;
650
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3);
651
+ }
652
+
653
+ .speed-slider {
654
+ width: 80px;
655
+ height: 4px;
656
+ border-radius: 2px;
657
+ background: rgba(255, 255, 255, 0.3);
658
+ outline: none;
659
+ cursor: pointer;
660
+ }
661
+
662
+ .speed-slider::-webkit-slider-thumb {
663
+ appearance: none;
664
+ width: 16px;
665
+ height: 16px;
666
+ border-radius: 50%;
667
+ background: white;
668
+ cursor: pointer;
669
+ }
670
+
671
+ .speed-slider::-moz-range-thumb {
672
+ width: 16px;
673
+ height: 16px;
674
+ border-radius: 50%;
675
+ background: white;
676
+ cursor: pointer;
677
+ border: none;
678
+ }
679
+
680
+ #speedValue {
681
+ font-family: 'Courier New', monospace;
682
+ font-weight: bold;
683
+ min-width: 40px;
684
+ }
685
+
686
+ /* Responsive design */
687
+ @media (max-width: 768px) {
688
+ .control-row-compact {
689
+ flex-direction: column;
690
+ gap: 10px;
691
+ }
692
+
693
+ .progress-group {
694
+ margin: 10px 0;
695
+ }
696
+
697
+ .info-group,
698
+ .speed-group {
699
+ justify-content: center;
700
+ }
701
+ }
702
+
703
+ /* Original styles */
704
+ .caption-container {
705
+ margin-top: 10px;
706
+ margin-bottom: 10px;
707
+ width: 100%;
708
+ }
709
+
710
+ .motion-info {
711
+ background-color: #ffffff;
712
+ border-radius: 10px;
713
+ box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
714
+ overflow: hidden;
715
+ }
716
+
717
+ .loading {
718
+ padding: 20px;
719
+ text-align: center;
720
+ color: #666;
721
+ font-style: italic;
722
+ }
723
+
724
+ .file-info {
725
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
726
+ color: white;
727
+ padding: 20px;
728
+ }
729
+
730
+ .file-info h4 {
731
+ margin: 0 0 10px 0;
732
+ font-size: 1.2em;
733
+ font-weight: 600;
734
+ }
735
+
736
+ .file-detail {
737
+ margin: 5px 0;
738
+ font-size: 0.9em;
739
+ opacity: 0.9;
740
+ }
741
+
742
+ .captions-section {
743
+ padding: 20px;
744
+ }
745
+
746
+ .captions-title {
747
+ margin: 0 0 15px 0;
748
+ color: #333;
749
+ font-size: 1.1em;
750
+ font-weight: 600;
751
+ border-bottom: 2px solid #667eea;
752
+ padding-bottom: 8px;
753
+ }
754
+
755
+ .caption-item {
756
+ background-color: #f8f9ff;
757
+ border: 1px solid #e1e5f7;
758
+ border-radius: 8px;
759
+ margin-bottom: 15px;
760
+ padding: 15px;
761
+ transition: all 0.3s ease;
762
+ }
763
+
764
+ .caption-item:hover {
765
+ transform: translateY(-2px);
766
+ box-shadow: 0 4px 15px rgba(102, 126, 234, 0.15);
767
+ }
768
+
769
+ .caption-content {
770
+ margin-bottom: 10px;
771
+ }
772
+
773
+ .no-captions {
774
+ text-align: center;
775
+ padding: 30px;
776
+ color: #666;
777
+ font-style: italic;
778
+ }
779
+
780
+ .error-message {
781
+ background-color: #fee;
782
+ color: #c33;
783
+ padding: 15px;
784
+ border-radius: 5px;
785
+ text-align: center;
786
+ }
787
+
788
+ .caption-rewritten {
789
+ font-size: 1em;
790
+ font-weight: 600;
791
+ color: #333;
792
+ margin-bottom: 5px;
793
+ }
794
+
795
+ .caption-original-toggle {
796
+ font-size: 0.85em;
797
+ color: #667eea;
798
+ cursor: pointer;
799
+ display: inline-flex;
800
+ align-items: center;
801
+ gap: 4px;
802
+ padding: 4px 0;
803
+ user-select: none;
804
+ transition: color 0.2s;
805
+ }
806
+
807
+ .caption-original-toggle:hover {
808
+ color: #764ba2;
809
+ }
810
+
811
+ .caption-original-toggle i {
812
+ transition: transform 0.2s;
813
+ }
814
+
815
+ .caption-original-toggle.expanded i {
816
+ transform: rotate(180deg);
817
+ }
818
+
819
+ .caption-original {
820
+ display: none !important;
821
+ font-size: 0.9em;
822
+ color: #666;
823
+ background: #f0f0f5;
824
+ padding: 8px 12px;
825
+ border-radius: 6px;
826
+ margin-top: 6px;
827
+ line-height: 1.4;
828
+ }
829
+
830
+ .caption-original.show {
831
+ display: block !important;
832
+ }
833
+
834
+ .original-label {
835
+ font-weight: 600;
836
+ color: #888;
837
+ }
838
+ </style>
839
+
840
+
841
+ <!-- Only load caption fetching script if captions are not hidden -->
842
+ {% if not hide_captions %}
843
+ <script>
844
// Render one caption entry as an HTML string. Prefers the rewritten text
// ('short caption+'); when both an original and a rewritten caption exist,
// a collapsible "Show original text" section is included.
// NOTE(review): `index` is currently unused — kept for caller compatibility.
function createCaptionItem(caption, index) {
    const hasOriginal = caption['short caption'];
    const rewritten = caption['short caption+'] || caption['short caption'] || 'No caption';

    return `
        <div class="caption-item">
            <div class="caption-content">
                <div class="caption-rewritten">${rewritten}</div>
                ${hasOriginal && caption['short caption+'] ? `
                <div class="caption-original-toggle" data-toggle="caption-original">
                    <i class="fas fa-chevron-down"></i> Show original text
                </div>
                <div class="caption-original">
                    <span class="original-label">Original:</span> ${caption['short caption']}
                </div>
                ` : ''}
            </div>
        </div>
    `;
}
864
+
865
// Delegated click handler: expand/collapse the "original caption" panel
// that immediately follows each toggle element.
document.addEventListener('click', (event) => {
    const toggle = event.target.closest('[data-toggle="caption-original"]');
    if (!toggle) return;
    const panel = toggle.nextElementSibling;
    if (!panel || !panel.classList.contains('caption-original')) return;
    panel.classList.toggle('show');
    toggle.classList.toggle('expanded');
});
875
+
876
// Fetch file info and captions for this motion and render them into the
// #motion-info panel. Falls back through several payload shapes:
// filename/motion_path header, a `result` caption array, or a single
// legacy `caption` string.
function fetchMotionInfo() {
    fetch('/query_caption/{{ folder_name }}/{{ file_name }}')
        .then(response => response.json())
        .then(data => {
            const motionInfoElement = document.getElementById('motion-info');

            if (data && (data.filename || data.motion_path || data.result)) {
                let html = '';

                // Header block: filename and/or source motion path.
                if (data.filename || data.motion_path) {
                    html += `
                        <div class="file-info">
                            <h4>File Information</h4>
                            ${data.filename ? `<div class="file-detail"><strong>Filename:</strong> ${data.filename}</div>` : ''}
                            ${data.motion_path ? `<div class="file-detail"><strong>Motion Path:</strong> ${data.motion_path}</div>` : ''}
                        </div>
                    `;
                }

                // Caption list: one item per entry, or an empty-state notice
                // when `result` exists but is not a non-empty array.
                if (data.result && Array.isArray(data.result) && data.result.length > 0) {
                    html += `
                        <div class="captions-section">
                            <h5 class="captions-title">Motion Captions (${data.result.length})</h5>
                            ${data.result.map((caption, index) => createCaptionItem(caption, index)).join('')}
                        </div>
                    `;
                } else if (data.result) {
                    html += `
                        <div class="captions-section">
                            <div class="no-captions">No captions available for this motion</div>
                        </div>
                    `;
                }

                // Legacy single-caption payload shape.
                if (!html && data.caption) {
                    html = `
                        <div class="captions-section">
                            <h5 class="captions-title">Caption</h5>
                            <div class="caption-item">
                                <div class="caption-content">
                                    <div class="caption-short">${data.caption}</div>
                                </div>
                            </div>
                        </div>
                    `;
                }

                motionInfoElement.innerHTML = html || '<div class="no-captions">No motion information available</div>';
            } else {
                motionInfoElement.innerHTML = '<div class="no-captions">No motion information available</div>';
            }
        })
        .catch(error => {
            console.error('Error fetching motion information:', error);
            document.getElementById('motion-info').innerHTML = '<div class="error-message">Error loading motion information</div>';
        });
}
933
+
934
+ document.addEventListener('DOMContentLoaded', fetchMotionInfo);
935
+ </script>
936
+ {% endif %}
937
+
938
+ {% endblock %}