diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..bd74e467838c1a217a809dad22755561f46f9a69 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +3rdparty/fbxsdkpy-2020.1.post2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text +assets/wooden_models/*.fbx filter=lfs diff=lfs merge=lfs -text +assets/wooden_models/boy_Rigging_smplx_tex.fbm/Boy_lambert4_BaseColor.png filter=lfs diff=lfs merge=lfs -text +*.webp filter=lfs diff=lfs merge=lfs -text +*.whl filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f01cc9fe8d167b3f327abedfccae4078aa7d73dd --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +cache_config.json +__pycache__/ +*.pyc +*.pyo +*.pyd +*.pyw +*.pyz +*.pywz +*.pyzw +*.pyzwz +cache + +.vscode/ + +ckpts/* +!ckpts/README.md +assets/body_models/* +!assets/body_models/README.md +scripts/gradio/static/assets/dump_smplh +scripts/gradio/static/assets/export_wooden_to_js.py + +test_*/ +debug +tencent +output diff --git a/configs/base/config.yml b/configs/base/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..c581702d9bcd04e29475283b859180fee5f3e26f --- /dev/null +++ b/configs/base/config.yml @@ -0,0 +1,37 @@ +network_module: hymotion/network/hymotion_mmdit.HunyuanMotionMMDiT +network_module_args: + apply_rope_to_single_branch: false + ctxt_input_dim: 4096 + dropout: 0.0 + feat_dim: 1024 + input_dim: 201 + mask_mode: narrowband + mlp_ratio: 4.0 + num_heads: 16 + num_layers: 18 + time_factor: 1000.0 + vtxt_input_dim: 768 +train_pipeline: hymotion/pipeline/motion_diffusion.MotionFlowMatching +train_pipeline_args: + enable_ctxt_null_feat: true + enable_special_game_feat: true + infer_noise_scheduler_cfg: + validation_steps: 50 + losses_cfg: + recons: + name: SmoothL1Loss + weight: 1.0 + noise_scheduler_cfg: + method: euler + output_mesh_fps: 30 + random_generator_on_gpu: true + test_cfg: + mean_std_dir: ./stats/ + text_guidance_scale: 5.0 + text_encoder_cfg: + llm_type: qwen3 + max_length_llm: 128 + text_encoder_module: hymotion/network/text_encoders/text_encoder.HYTextModel + train_cfg: + cond_mask_prob: 0.1 + train_frames: 360 diff --git a/examples/example_prompts/example_subset.json b/examples/example_prompts/example_subset.json new file mode 100644 index 0000000000000000000000000000000000000000..c125425324b57405c2a40ff1c4973bfb0508a727 --- /dev/null +++ b/examples/example_prompts/example_subset.json @@ -0,0 +1,61 @@ +{ + "test_prompts_subset": [ + "A person jumps upward with both legs twice.#90#none#001", + "A person jumps on their right leg.#90#none#002", + "A person climbs upward, moving up the slope.#60#none#003", + "A person climbs an obstacle.#60#none#004", + "A person walks forward.#120#none#005", + "A person walks forward, moving arms and legs while looking left and right.#180#none#006", + "A person walks unsteadily, then slowly sits down.#150#none#007", + "A person turns backward 180 degrees, then walks forward.#120#none#008", + "A person walks in a catwalk style, swinging their left arm while placing their right hand on their hip.#180#none#009", + "A person squats down on tiptoe#120#none#010", + "A person sits down on a chair.#90#none#011", + "A person runs forward.#60#none#012", + "A person jumps up.#90#none#013", + "A person jumps forward lightly, taking two steps.#69#none#014", + "A person shoots a basketball.#60#none#015", + "A person finishes freestyle swimming, then surfaces.#120#none#016", + "A person swings a golf club, hitting the ball forward.#111#none#017", + "A person runs forward, then kicks a soccer ball.#60#none#018", + "A person walks on a tightrope.#180#none#019", + "A person performs a yoga camel pose, extending their back and lifting their chest.#210#none#020", + "A person performs a sit-up, holding their head with both hands.#150#none#021", + "A person performs a lunge stretch, hands on hips.#150#none#022", + "A person performs a deadlift, lifting a barbell from the ground.#150#none#023", + "A person marches in place, swinging their arms forward and backward.#210#none#024", + "A person perform a squat, not standing up#93#none#025", + "A person performs a squat#93#none#026", + "A person performs a front arm raise, then does a squat.#93#none#027", + "A person performs a squat, raising both arms forward.#240#none#028", + "A person does a squat, balling both hands into fists, lowering into a squat, then standing up.#195#none#029", + "A person plays the piano.#270#none#030", + "A person dances bachata, executing rhythmic hip movements and footwork.#240#none#031", + "A person plays the drums while sitting down, with wide, crossing arm movements.#90#none#032", + "A person plays the drums while sitting down, with arms spreading wide and then crossing over.#90#none#033", + "A person dances jazz, jumping rhythmically.#240#none#034", + "A person practices tai chi, performing slow, controlled movements.#270#none#035", + "A person waves their right hand, sitting on a beach chair.#71#none#036", + "A person was sweeping the floor with their head down.#180#none#037", + "A person picks up an object from ground#117#none#038", + "A person picks up an object from lower ground with two hands#99#none#039", + "A person picks up an object from lower ground with two hands, and lifts over head#126#none#040", + "A person speaks, gesturing with both hands.#75#none#041", + "A person lies on a bed, reading a book.#180#none#042", + "A person bends down to pick up an object, then stands up straight.#150#none#043", + "A person flips the wok#61#none#044", + "A person rolls over while lying down.#60#none#045", + "A person walks forward, holding a tray at shoulder height with one hand.#93#none#046", + "A person stands up from the chair, then stretches the arms.#300#none#047", + "A person turns to evade.#61#none#048", + "A person collapses to the ground after being hit.#60#none#049", + "A person swings a sword forward.#60#none#050", + "A person attacks, holding a shield in the right hand and a sword in the left.#45#none#051", + "A person walks like a zombie, dragging their feet forward.#120#none#052", + "A person performs a taekwondo kick, extending their leg forcefully.#60#none#053", + "A person blocks with a shield.#60#none#054", + "A person lifts a long gun, then walks forward slowly.#90#none#055", + "A person stumbles, being hit.#45#none#056", + "A person assumes a boxing stance, then shifts weight to the right and punches with the right hand.#60#none#057" + ] +} \ No newline at end of file diff --git a/gradio_app.py b/gradio_app.py index 305bd9cf9b4250878a2e04dcbf754c6572b32ed4..6a888322e8c5e76fbc92b17ff3840a7933790773 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -1,28 +1,764 @@ +# we should use gradio==5.38.2 +import argparse +import codecs as cs +import json +import os +import os.path as osp +import random +import re +import textwrap +from typing import List, Optional, Tuple, Union + import gradio as gr -import sys - -try: - import torch - torch_version = torch.__version__ -except ImportError: - print("torch not found, please install it") - torch_version = "not found" - -try: - import fbx - try: - fbx_version = fbx.__version__ - except AttributeError: - # fbx module doesn't have __version__ attribute - fbx_version = "installed (version unknown)" -except ImportError: - print("fbx not found, please install it") - fbx_version = "not found" - -def greet(name): - python_version = sys.version - version = torch_version + " fbx version: " + fbx_version - return "Hello " + name + "!! torch version: " + version + " python version: " + python_version - -demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() \ No newline at end of file +import torch + +from hymotion.utils.t2m_runtime import T2MRuntime + +NUM_WORKERS = torch.cuda.device_count() if torch.cuda.is_available() else 1 + + +# define data sources +DATA_SOURCES = { + "example_prompts": "examples/example_prompts/example_subset.json", +} + +# create interface +APP_CSS = """ + :root{ + --primary-start:#667eea; --primary-end:#764ba2; + --secondary-start:#4facfe; --secondary-end:#00f2fe; + --accent-start:#f093fb; --accent-end:#f5576c; + --page-bg:linear-gradient(135deg,#f5f7fa 0%,#c3cfe2 100%); + --card-bg:linear-gradient(135deg,#ffffff 0%,#f8f9fa 100%); + --radius:12px; + --iframe-bg:#ffffff; + } + + /* Dark mode variables */ + [data-theme="dark"], .dark { + --page-bg:linear-gradient(135deg,#1a1a1a 0%,#2d3748 100%); + --card-bg:linear-gradient(135deg,#2d3748 0%,#374151 100%); + --text-primary:#f7fafc; + --text-secondary:#e2e8f0; + --border-color:#4a5568; + --input-bg:#374151; + --input-border:#4a5568; + --iframe-bg:#1a1a2e; + } + + /* Page and card */ + .gradio-container{ + background:var(--page-bg) !important; + min-height:100vh !important; + color:var(--text-primary, #333) !important; + } + + .main-header{ + background:transparent !important; border:none !important; box-shadow:none !important; + padding:0 !important; margin:10px 0 16px !important; + text-align:center !important; + } + + .main-header h1, .main-header p, .main-header li { + color:var(--text-primary, #333) !important; + } + + .left-panel,.right-panel{ + background:var(--card-bg) !important; + border:1px solid var(--border-color, #e9ecef) !important; + border-radius:15px !important; + box-shadow:0 4px 20px rgba(0,0,0,.08) !important; + padding:24px !important; + } + + .gradio-accordion{ + border:1px solid var(--border-color, #e1e5e9) !important; + border-radius:var(--radius) !important; + margin:12px 0 !important; background:transparent !important; + } + + .gradio-accordion summary{ + background:transparent !important; + padding:14px 18px !important; + font-weight:600 !important; + color:var(--text-primary, #495057) !important; + } + + .gradio-group{ + background:transparent !important; border:none !important; + border-radius:8px !important; padding:12px 0 !important; margin:8px 0 !important; + } + + /* Input class style - dark mode adaptation */ + .gradio-textbox input,.gradio-textbox textarea,.gradio-dropdown .wrap{ + border-radius:8px !important; + border:2px solid var(--input-border, #e9ecef) !important; + background:var(--input-bg, #fff) !important; + color:var(--text-primary, #333) !important; + transition:.2s all !important; + } + + .gradio-textbox input:focus,.gradio-textbox textarea:focus,.gradio-dropdown .wrap:focus-within{ + border-color:var(--primary-start) !important; + box-shadow:0 0 0 3px rgba(102,126,234,.1) !important; + } + + .gradio-slider input[type="range"]{ + background:linear-gradient(to right,var(--primary-start),var(--primary-end)) !important; + border-radius:10px !important; + } + + .gradio-checkbox input[type="checkbox"]{ + border-radius:4px !important; + border:2px solid var(--input-border, #e9ecef) !important; + transition:.2s all !important; + } + + .gradio-checkbox input[type="checkbox"]:checked{ + background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important; + border-color:var(--primary-start) !important; + } + + /* Label text color adaptation */ + .gradio-textbox label, .gradio-dropdown label, .gradio-slider label, + .gradio-checkbox label, .gradio-html label { + color:var(--text-primary, #333) !important; + } + + .gradio-textbox .info, .gradio-dropdown .info, .gradio-slider .info, + .gradio-checkbox .info { + color:var(--text-secondary, #666) !important; + } + + /* Status information - dark mode adaptation */ + .gradio-textbox[data-testid*="状态信息"] input{ + background:var(--input-bg, linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%)) !important; + border:2px solid var(--input-border, #dee2e6) !important; + color:var(--text-primary, #495057) !important; + font-weight:500 !important; + } + + /* Button base class and variant */ + .generate-button,.rewrite-button,.dice-button{ + border:none !important; color:#fff !important; font-weight:600 !important; + border-radius:8px !important; transition:.3s all !important; + box-shadow:0 4px 15px rgba(0,0,0,.12) !important; + } + + .generate-button{ background:linear-gradient(45deg,var(--primary-start),var(--primary-end)) !important; } + .rewrite-button{ background:linear-gradient(45deg,var(--secondary-start),var(--secondary-end)) !important; } + .dice-button{ + background:linear-gradient(45deg,var(--accent-start),var(--accent-end)) !important; + height:40px !important; + } + + .generate-button:hover,.rewrite-button:hover{ transform:translateY(-2px) !important; } + .dice-button:hover{ + transform:scale(1.05) !important; + box-shadow:0 4px 12px rgba(240,147,251,.28) !important; + } + + .dice-container{ + display:flex !important; + align-items:flex-end !important; + justify-content:center !important; + } + + /* Right panel clipping overflow, avoid double scrollbars */ + .right-panel{ + background:var(--card-bg) !important; + border:1px solid var(--border-color, #e9ecef) !important; + border-radius:15px !important; + box-shadow:0 4px 20px rgba(0,0,0,.08) !important; + padding:24px !important; overflow:hidden !important; + } + + /* Main content row - ensure equal heights */ + .main-row { + display: flex !important; + align-items: stretch !important; + } + + /* Flask area - match left panel height */ + .flask-display{ + padding:0 !important; margin:0 !important; border:none !important; + box-shadow:none !important; background:var(--iframe-bg) !important; + border-radius:10px !important; position:relative !important; + height:100% !important; min-height:750px !important; + display:flex !important; flex-direction:column !important; + } + + .flask-display iframe{ + width:100% !important; flex:1 !important; min-height:750px !important; + border:none !important; border-radius:10px !important; display:block !important; + background:var(--iframe-bg) !important; + } + + /* Right panel should stretch to match left panel */ + .right-panel{ + background:var(--card-bg) !important; + border:1px solid var(--border-color, #e9ecef) !important; + border-radius:15px !important; + box-shadow:0 4px 20px rgba(0,0,0,.08) !important; + padding:24px !important; overflow:hidden !important; + display:flex !important; flex-direction:column !important; + } + + /* Ensure dropdown menu is visible in dark mode */ + [data-theme="dark"] .gradio-dropdown .wrap, + .dark .gradio-dropdown .wrap { + background:var(--input-bg) !important; + color:var(--text-primary) !important; + } + + [data-theme="dark"] .gradio-dropdown .option, + .dark .gradio-dropdown .option { + background:var(--input-bg) !important; + color:var(--text-primary) !important; + } + + [data-theme="dark"] .gradio-dropdown .option:hover, + .dark .gradio-dropdown .option:hover { + background:var(--border-color) !important; + } + + .footer{ + text-align:center !important; + margin-top:20px !important; + padding:10px !important; + color:var(--text-secondary, #666) !important; + } +""" + +HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground" + +FOOTER_MD = "*This is a Beta version, any issues or feedback are welcome!*" + + +def load_examples_from_txt(txt_path: str): + """Load examples from txt file.""" + + def _parse_line(line: str) -> Optional[Tuple[str, float]]: + line = line.strip() + if line and not line.startswith("#"): + parts = line.split("#") + if len(parts) >= 2: + text = parts[0].strip() + duration = int(parts[1]) / 20.0 + else: + text = line.strip() + duration = 5.0 + return text, duration + return None + + examples: List[Tuple[str, float]] = [] + if os.path.exists(txt_path): + try: + if txt_path.endswith(".txt"): + with cs.open(txt_path, "r", encoding="utf-8") as f: + lines = f.readlines() + for line in lines: + result = _parse_line(line) + if result is None: + continue + text, duration = result + examples.append((text, duration)) + elif txt_path.endswith(".json"): + with cs.open(txt_path, "r", encoding="utf-8") as f: + lines = json.load(f) + for key, value in lines.items(): + if "_raw_chn" in key or "GENERATE_PROMPT_FORMAT" in key: + continue + for line in value: + result = _parse_line(line) + if result is None: + continue + text, duration = result + examples.append((text, duration)) + print(f">>> Loaded {len(examples)} examples from {txt_path}") + except Exception as e: + print(f">>> Failed to load examples from {txt_path}: {e}") + else: + print(f">>> Examples file not found: {txt_path}") + + return examples + + +class T2MGradioUI: + def __init__(self, runtime: T2MRuntime, args: argparse.Namespace): + self.runtime = runtime + self.args = args + + # Check if rewrite is available: + # - prompt_engineering_host must be provided + # - disable_rewrite must not be set + print(f">>> args: {vars(args)}") + self.rewrite_available = ( + args.prompt_engineering_host is not None + and args.prompt_engineering_host.strip() != "" + and not args.disable_rewrite + ) + + self.all_example_data = {} + self._init_example_data() + + def _init_example_data(self): + for source_name, file_path in DATA_SOURCES.items(): + examples = load_examples_from_txt(file_path) + if examples: + self.all_example_data[source_name] = examples + else: + # provide default examples as fallback + self.all_example_data[source_name] = [ + ("Twist at the waist and punch across the body.", 3.0), + ("A person is running then takes big leap.", 3.0), + ("A person holds a railing and walks down a set of stairs.", 5.0), + ( + "A man performs a fluid and rhythmic hip-hop style dance, incorporating body waves, arm gestures, and side steps.", + 5.0, + ), + ] + print(f">>> Loaded data sources: {list(self.all_example_data.keys())}") + + def _get_header_text(self): + return HEADER_BASE_MD + + def _generate_random_seeds(self): + seeds = [random.randint(0, 999) for _ in range(4)] + return ",".join(map(str, seeds)) + + def _prompt_engineering( + self, text: str, duration: float, enable_rewrite: bool = True, enable_duration_est: bool = True + ): + if not text.strip(): + return "", gr.update(interactive=False), gr.update() + + call_llm = enable_rewrite or enable_duration_est + if not call_llm: + print(f"\t>>> Using original duration and original text...") + predicted_duration = duration + rewritten_text = text + else: + print(f"\t>>> Using LLM to estimate duration/rewrite text...") + try: + predicted_duration, rewritten_text = self.runtime.rewrite_text_and_infer_time(text=text) + except Exception as e: + print(f"\t>>> Text rewriting/duration prediction failed: {e}") + return ( + f"❌ Text rewriting/duration prediction failed: {str(e)}", + gr.update(interactive=False), + gr.update(), + ) + if not enable_rewrite: + rewritten_text = text + if not enable_duration_est: + predicted_duration = duration + + return rewritten_text, gr.update(interactive=True), gr.update(value=predicted_duration) + + def _generate_motion( + self, + original_text: str, + rewritten_text: str, + seed_input: str, + duration: float, + cfg_scale: float, + ) -> Tuple[str, List[str]]: + # When rewrite is not available, use original_text directly + if not self.rewrite_available: + text_to_use = original_text.strip() + if not text_to_use: + return "Error: Input text is empty, please enter text first", [] + else: + text_to_use = rewritten_text.strip() + if not text_to_use: + return "Error: Rewritten text is empty, please rewrite the text first", [] + + try: + fbx_ok = getattr(self.runtime, "fbx_available", False) + req_format = "fbx" if fbx_ok else "dict" + html, fbx_files, _ = self.runtime.generate_motion( + text=text_to_use, + seeds_csv=seed_input, + duration=duration, + cfg_scale=cfg_scale, + output_format=req_format, + original_text=original_text, + output_dir=self.args.output_dir, + ) + iframe_html = f""" + + """ + return iframe_html, fbx_files + except Exception as e: + print(f"\t>>> Motion generation failed: {e}") + return ( + f"❌ Motion generation failed: {str(e)}\n\nPlease check the input parameters or try again later", + [], + ) + + def _get_example_choices(self): + """Get all example choices from all data sources""" + choices = ["Custom Input"] + for source_name in self.all_example_data: + example_data = self.all_example_data[source_name] + for text, _ in example_data: + display_text = f"{text[:50]}..." if len(text) > 50 else text + choices.append(display_text) + return choices + + def _on_example_select(self, selected_example): + """When selecting an example, the callback function""" + if selected_example == "Custom Input": + return "", self._generate_random_seeds(), gr.update() + else: + # find the corresponding example from all data sources + for source_name in self.all_example_data: + example_data = self.all_example_data[source_name] + for text, duration in example_data: + display_text = f"{text[:50]}..." if len(text) > 50 else text + if display_text == selected_example: + return text, self._generate_random_seeds(), gr.update(value=duration) + return "", self._generate_random_seeds(), gr.update() + + def build_ui(self): + with gr.Blocks(css=APP_CSS) as demo: + self.header_md = gr.Markdown(HEADER_BASE_MD, elem_classes=["main-header"]) + + with gr.Row(): + # Left control panel + with gr.Column(scale=2, elem_classes=["left-panel"]): + # Input textbox + self.text_input = gr.Textbox( + label="📝 Input Text", + placeholder="Enter text to generate motion, support Chinese and English text input.", + ) + # Rewritten textbox + self.rewritten_text = gr.Textbox( + label="✏️ Rewritten Text", + placeholder="Rewritten text will be displayed here, you can further edit", + interactive=True, + visible=False, + ) + # Duration slider + self.duration_slider = gr.Slider( + minimum=0.5, + maximum=12, + value=5.0, + step=0.1, + label="⏱️ Action Duration (seconds)", + info="Feel free to adjust the action duration", + ) + + # Execute buttons + with gr.Row(): + if self.rewrite_available: + self.rewrite_btn = gr.Button( + "🔄 Rewrite Text", + variant="secondary", + size="lg", + elem_classes=["rewrite-button"], + ) + else: + # Create a hidden/disabled placeholder button + self.rewrite_btn = gr.Button( + "🔄 Rewrite Text (Unavailable)", + variant="secondary", + size="lg", + elem_classes=["rewrite-button"], + interactive=False, + visible=False, + ) + + self.generate_btn = gr.Button( + "🚀 Generate Motion", + variant="primary", + size="lg", + elem_classes=["generate-button"], + interactive=not self.rewrite_available, # Enable directly if rewrite not available + ) + + if not self.rewrite_available: + gr.Markdown( + "> ⚠️ **Prompt engineering is not available.** Text rewriting and duration estimation are disabled. Your input text and duration will be used directly." + ) + + # Advanced settings + with gr.Accordion("🔧 Advanced Settings", open=False): + self._build_advanced_settings() + + # Example selection dropdown + self.example_dropdown = gr.Dropdown( + choices=self._get_example_choices(), + value="Custom Input", + label="📚 Test Examples", + info="Select a preset example or input your own text above", + interactive=True, + ) + + # Status message depends on whether rewrite is available + if self.rewrite_available: + status_msg = "Please click the [🔄 Rewrite Text] button to rewrite the text first" + else: + status_msg = "Enter your text and click [🚀 Generate Motion] directly." + + self.status_output = gr.Textbox( + label="📊 Status Information", + value=status_msg, + ) + + # FBX Download section + with gr.Row(visible=False) as self.fbx_download_row: + if getattr(self.runtime, "fbx_available", False): + self.fbx_files = gr.File( + label="📦 Download FBX Files", + file_count="multiple", + interactive=False, + ) + else: + self.fbx_files = gr.State([]) + + # Right display area + with gr.Column(scale=3): + self.output_display = gr.HTML(show_label=False, elem_classes=["flask-display"]) + + # Footer + gr.Markdown(FOOTER_MD, elem_classes=["footer"]) + + self._bind_events() + demo.load(fn=self._get_header_text, outputs=[self.header_md]) + return demo + + def _build_advanced_settings(self): + # Only show rewrite options if rewrite is available + if self.rewrite_available: + with gr.Group(): + gr.Markdown("### 🔄 Text Rewriting Options") + with gr.Row(): + self.enable_rewrite = gr.Checkbox( + label="Enable Text Rewriting", + value=True, + info="Automatically optimize text prompt to get better motion generation", + ) + + with gr.Group(): + gr.Markdown("### ⏱️ Duration Settings") + self.enable_duration_est = gr.Checkbox( + label="Enable Duration Estimation", + value=True, + info="Automatically estimate the duration of the motion", + ) + else: + # Create hidden placeholders with default values (disabled) + self.enable_rewrite = gr.Checkbox( + label="Enable Text Rewriting", + value=False, + visible=False, + ) + self.enable_duration_est = gr.Checkbox( + label="Enable Duration Estimation", + value=False, + visible=False, + ) + with gr.Group(): + gr.Markdown("### ⚠️ Prompt Engineering Unavailable") + gr.Markdown( + "Text rewriting and duration estimation are not available. " + "Your input text and duration will be used directly." + ) + + with gr.Group(): + gr.Markdown("### ⚙️ Generation Parameters") + with gr.Row(): + with gr.Column(scale=3): + self.seed_input = gr.Textbox( + label="🎯 Random Seed List (comma separated)", + value="0,1,2,3", + placeholder="Enter comma separated seed list (e.g.: 0,1,2,3)", + info="Random seeds control the diversity of generated motions", + ) + with gr.Column(scale=1, min_width=60, elem_classes=["dice-container"]): + self.dice_btn = gr.Button( + "🎲 Lucky Button", + variant="secondary", + size="sm", + elem_classes=["dice-button"], + ) + + self.cfg_slider = gr.Slider( + minimum=1, + maximum=10, + value=5.0, + step=0.1, + label="⚙️ CFG Strength", + info="Text fidelity: higher = more faithful to the prompt", + ) + + def _bind_events(self): + # Generate random seeds + self.dice_btn.click(self._generate_random_seeds, outputs=[self.seed_input]) + + # Bind example selection event + self.example_dropdown.change( + fn=self._on_example_select, + inputs=[self.example_dropdown], + outputs=[self.text_input, self.seed_input, self.duration_slider], + ) + + # Rewrite text logic (only bind when rewrite is available) + if self.rewrite_available: + self.rewrite_btn.click(fn=lambda: "Rewriting text, please wait...", outputs=[self.status_output]).then( + self._prompt_engineering, + inputs=[ + self.text_input, + self.duration_slider, + self.enable_rewrite, + self.enable_duration_est, + ], + outputs=[self.rewritten_text, self.generate_btn, self.duration_slider], + ).then( + fn=lambda: ( + gr.update(visible=True), + "Text rewriting completed! Please check and edit the rewritten text, then click [🚀 Generate Motion]", + ), + outputs=[self.rewritten_text, self.status_output], + ) + + # Generate motion logic + self.generate_btn.click( + fn=lambda: "Generating motion, please wait... (It takes some extra time to start the renderer for the first generation)", + outputs=[self.status_output], + ).then( + self._generate_motion, + inputs=[ + self.text_input, + self.rewritten_text, + self.seed_input, + self.duration_slider, + self.cfg_slider, + ], + outputs=[self.output_display, self.fbx_files], + concurrency_limit=NUM_WORKERS, + ).then( + fn=lambda fbx_list: ( + ( + "🎉 Motion generation completed! You can view the motion visualization result on the right. FBX files are ready for download." + if fbx_list + else "🎉 Motion generation completed! You can view the motion visualization result on the right" + ), + gr.update(visible=bool(fbx_list)), + ), + inputs=[self.fbx_files], + outputs=[self.status_output, self.fbx_download_row], + ) + + # Reset logic - different behavior based on rewrite availability + if self.rewrite_available: + self.text_input.change( + fn=lambda: ( + gr.update(visible=False), + gr.update(interactive=False), + "Please click the [🔄 Rewrite Text] button to rewrite the text first", + ), + outputs=[self.rewritten_text, self.generate_btn, self.status_output], + ) + else: + # When rewrite is not available, enable generate button directly when text is entered + self.text_input.change( + fn=lambda text: ( + gr.update(visible=False), + gr.update(interactive=bool(text.strip())), + ( + "Ready to generate! Click [🚀 Generate Motion] to start." + if text.strip() + else "Enter your text and click [🚀 Generate Motion] directly." + ), + ), + inputs=[self.text_input], + outputs=[self.rewritten_text, self.generate_btn, self.status_output], + ) + # Only bind rewritten_text change when rewrite is available + if self.rewrite_available: + self.rewritten_text.change( + fn=lambda text: ( + gr.update(interactive=bool(text.strip())), + ( + "Rewritten text has been modified, you can click [🚀 Generate Motion]" + if text.strip() + else "Rewritten text cannot be empty, please enter valid text" + ), + ), + inputs=[self.rewritten_text], + outputs=[self.generate_btn, self.status_output], + ) + + +if __name__ == "__main__": + # parser = argparse.ArgumentParser(description="HY-Motion-1.0 Text-to-Motion Gradio") + # parser.add_argument("--model_path", type=str, required=True, help="Configuration file path") + # parser.add_argument( + # "--device_ids", type=str, default=None, help="GPU device ID list, separated by commas, e.g.: 0,1,2,3" + # ) + # parser.add_argument( + # "--prompt_engineering_host", + # type=str, + # default=None, + # help="Prompt engineering host address, for text rewriting and duration estimation", + # ) + # parser.add_argument("--output_dir", type=str, default="output/gradio", help="Output directory") + # parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server name") + # parser.add_argument("--port", type=int, default=8080, help="Server port") + # parser.add_argument("--disable_flask_server", action="store_true") + # parser.add_argument( + # "--disable_rewrite", + # action="store_true", + # help="Disable text rewriting and duration estimation, use input text and duration directly", + # ) + # args = parser.parse_args() + + final_model_path = './configs/base' + + class Args: + model_path = final_model_path + output_dir = "output/gradio" + prompt_engineering_host = os.environ.get("PROMPT_HOST", None) + disable_rewrite = False + + args = Args() + + # Check required files: + cfg = osp.join(args.model_path, "config.yml") + ckpt = osp.join(args.model_path, "latest.ckpt") + if not osp.exists(cfg): + raise FileNotFoundError(f">>> Configuration file not found: {cfg}") + + # Check checkpoint file - skip loading if not exists + skip_model_loading = False + if not os.path.exists(ckpt): + print(f">>> [WARNING] Checkpoint file not found: {ckpt}") + print(f">>> [WARNING] Model loading will be skipped. Motion generation will not be available.") + skip_model_loading = True + + # Initialize runtime + print(">>> Initializing T2MRuntime...") + runtime = T2MRuntime( + config_path=cfg, + ckpt_name=ckpt, + device_ids=None, + prompt_engineering_host=args.prompt_engineering_host, + skip_model_loading=skip_model_loading, + ) + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + ui = T2MGradioUI(runtime=runtime, args=args) + demo = ui.build_ui() + + demo.launch() diff --git a/hymotion/network/attention.py b/hymotion/network/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..25e376a328ad602e864c1b383d712dd0e901692a --- /dev/null +++ b/hymotion/network/attention.py @@ -0,0 +1,110 @@ +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor + +try: + import flash_attn + from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func +except ImportError: + flash_attn = None + flash_attn_varlen_func = None + _flash_attn_forward = None + + +MEMORY_LAYOUT = { + "flash": (lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), lambda x: x), + "torch": (lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2)), + "vanilla": (lambda x: x.transpose(1, 2), lambda x: x.transpose(1, 2)), +} + + +def attention( + q: Tensor, + k: Tensor, + v: Tensor, + mode: str = "flash", + drop_rate: float = 0.0, + attn_mask: Optional[Tensor] = None, + causal: bool = False, + cu_seqlens_q: Optional[Tensor] = None, + cu_seqlens_kv: Optional[Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + batch_size: int = 1, + training: bool = True, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Perform QKV self attention. + + Args: + q (Tensor): Query tensor with shape [b, s, h, d], where h is the number of heads. + k (Tensor): Key tensor with shape [b, s1, h, d] + v (Tensor): Value tensor with shape [b, s1, h, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, h, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + Tensor: Output tensor after self attention with shape [b, s, hd] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + q = pre_attn_layout(q) + k = pre_attn_layout(k) + v = pre_attn_layout(v) + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + elif mode == "flash": + assert flash_attn_varlen_func is not None, "flash_attn is not installed or not supported" + x = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ) + # x with shape [(bxs), a, d] + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1.0 / math.sqrt(q.size(-1)) + b, a, s_q, _ = q.shape + s_k = k.size(2) + attn_bias = torch.zeros(b, a, s_q, s_k, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s_q, s_q, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(~temp_mask, float("-inf")) + attn_bias = attn_bias.to(q.dtype) + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(~attn_mask, float("-inf")) + else: + attn_bias = attn_bias + attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn = attn + attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=training) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, h, d = x.shape + out = x.reshape(b, s, -1) + return out diff --git a/hymotion/network/bricks.py b/hymotion/network/bricks.py new file mode 100644 index 0000000000000000000000000000000000000000..b4d507620c9e020bbd9fefcae5bdbc8d4935c054 --- /dev/null +++ b/hymotion/network/bricks.py @@ -0,0 +1,46 @@ +from typing import Callable, Optional + +import torch +import torch.nn as nn +from torch import Tensor + + +def get_activation_layer(act_type: str) -> Callable[[], nn.Module]: + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") + + +def get_norm_layer(norm_type: Optional[str]): + if norm_type == "layer": + return nn.LayerNorm + elif norm_type == "rms": + return RMSNorm + elif norm_type == "none" or norm_type is None: + return nn.Identity + else: + raise ValueError(f"Unknown norm type: {norm_type}") + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6) -> None: + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: Tensor) -> Tensor: + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output diff --git a/hymotion/network/encoders.py b/hymotion/network/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..931b1fb85862930448dd83fbe296b6164c72dce0 --- /dev/null +++ b/hymotion/network/encoders.py @@ -0,0 +1,121 @@ +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from ..utils.misc import to_2tuple +from .bricks import get_activation_layer, get_norm_layer +from .modulate_layers import ModulateDiT, modulate + + +class MLP(nn.Module): + def __init__( + self, + in_dim: int, + feat_dim: int, + out_dim: Optional[int] = None, + act_type: str = "gelu", + norm_type: Optional[str] = None, + bias: bool = True, + drop: float = 0.0, + use_conv: bool = False, + ) -> None: + super().__init__() + out_dim = out_dim or in_dim + feat_dim = feat_dim or in_dim + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_dim, feat_dim, bias=bias[0] if isinstance(bias, (list, tuple)) else bias) + self.act = get_activation_layer(act_type)() + self.drop1 = nn.Dropout(drop_probs[0] if isinstance(drop_probs, (list, tuple)) else drop_probs) + self.norm = get_norm_layer(norm_type)(feat_dim) if norm_type else nn.Identity() + self.fc2 = linear_layer(feat_dim, out_dim, bias=bias[1] if isinstance(bias, (list, tuple)) else bias) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class MLPEncoder(nn.Module): + def __init__(self, in_dim: int, feat_dim: int, num_layers: int, act_type: str = "silu") -> None: + super(MLPEncoder, self).__init__() + self.in_dim = in_dim + self.feat_dim = feat_dim + linears = [] + linears.append(nn.Linear(in_features=in_dim, out_features=self.feat_dim)) + for i in range(num_layers - 1): + linears.append(get_activation_layer(act_type)()) + linears.append(nn.Linear(self.feat_dim, self.feat_dim)) + self.linears = nn.Sequential(*linears) + + def forward(self, x: Tensor) -> Tensor: + return self.linears(x) + + +class FinalLayer(nn.Module): + def __init__(self, feat_dim: int, out_dim: int, act_type: str = "gelu", zero_init=False, **kwargs): + super().__init__() + + self.norm_final = nn.LayerNorm(feat_dim, elementwise_affine=False, eps=1e-6) + self.adaLN_modulation = ModulateDiT(feat_dim, factor=2, act_type=act_type) + self.linear = nn.Linear(feat_dim, out_dim, bias=True) + if zero_init: + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: Tensor, adapter: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(adapter).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x + + +class TimestepEmbeddingEncoder(nn.Module): + def __init__( + self, + embedding_dim: int, + feat_dim: int, + act_type: str = "silu", + time_factor: float = 1.0, + ) -> None: + super(TimestepEmbeddingEncoder, self).__init__() + + self.embedding_dim = embedding_dim + self.feat_dim = feat_dim + self.time_factor = time_factor + blocks = [ + nn.Linear(embedding_dim, self.feat_dim), + get_activation_layer(act_type)(), + nn.Linear(self.feat_dim, self.feat_dim), + ] + self.blocks = nn.Sequential(*blocks) + + def forward(self, t: Tensor) -> Tensor: + x = self.blocks(self.sinusodial_embedding(t, self.embedding_dim, time_factor=self.time_factor)).unsqueeze(1) + return x + + @staticmethod + def sinusodial_embedding( + timesteps: Tensor, embedding_dim: int, temperature: float = 10000.0, time_factor: float = 1.0 + ) -> Tensor: + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + timesteps = timesteps * time_factor + half = embedding_dim // 2 + freqs = torch.exp( + -torch.log(torch.tensor(temperature)) * torch.arange(start=0, end=half, dtype=torch.float) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if embedding_dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/hymotion/network/hymotion_mmdit.py b/hymotion/network/hymotion_mmdit.py new file mode 100644 index 0000000000000000000000000000000000000000..79c56417f568767d9644916df78ca4415947e49d --- /dev/null +++ b/hymotion/network/hymotion_mmdit.py @@ -0,0 +1,636 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor + +from ..utils.loaders import load_object +from ..utils.type_converter import get_module_device +from .attention import attention +from .bricks import get_activation_layer, get_norm_layer +from .encoders import MLP, MLPEncoder, TimestepEmbeddingEncoder +from .modulate_layers import ModulateDiT, apply_gate, modulate +from .positional_encoding import RotaryEmbedding + + +class MMBaseBlock(nn.Module): + def __init__( + self, + feat_dim: int, + num_heads: int, + mlp_ratio: float, + dropout: float, + positional_encoding_cfg: dict, + apply_rope_to_single_branch: bool, + ): + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.dropout = dropout + + assert self.feat_dim % num_heads == 0, f"feat_dim {self.feat_dim} must be divisible by num_heads {num_heads}" + self.head_dim = self.feat_dim // num_heads + + self.mlp_hidden_dim = int(self.feat_dim * mlp_ratio) + + self._positional_encoding_cfg = positional_encoding_cfg.copy() + self.rotary_emb = RotaryEmbedding(num_feats=self.head_dim, **self._positional_encoding_cfg) + self.apply_rope_to_single_branch = apply_rope_to_single_branch + + +class MMDoubleStreamBlock(MMBaseBlock): + def __init__( + self, + feat_dim: int, + num_heads: int, + mlp_ratio: float, + dropout: float, + mlp_act_type: str, + qk_norm_type: Optional[str] = None, + qkv_bias: bool = False, + positional_encoding_cfg: dict = { + "max_seq_len": 5000, + "use_real": True, + }, + apply_rope_to_single_branch: bool = True, + ): + super().__init__(feat_dim, num_heads, mlp_ratio, dropout, positional_encoding_cfg, apply_rope_to_single_branch) + + self.motion_mod = ModulateDiT( + self.feat_dim, + factor=6, + act_type="silu", + ) + self.motion_norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6) + + motion_qkv_out_dim = self.feat_dim * 3 + self.motion_qkv = nn.Linear(self.feat_dim, motion_qkv_out_dim, bias=qkv_bias) + + self.motion_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6) + self.motion_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6) + self.motion_out_proj = nn.Linear(self.feat_dim, self.feat_dim, bias=qkv_bias) + self.motion_norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6) + self.motion_mlp = MLP( + self.feat_dim, + self.mlp_hidden_dim, + act_type=mlp_act_type, + bias=True, + ) + + self.text_mod = ModulateDiT( + self.feat_dim, + factor=6, + act_type="silu", + ) + self.text_norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6) + + text_qkv_out_dim = self.feat_dim * 3 + self.text_qkv = nn.Linear(self.feat_dim, text_qkv_out_dim, bias=qkv_bias) + + self.text_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6) + self.text_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6) + self.text_out_proj = nn.Linear(self.feat_dim, self.feat_dim, bias=qkv_bias) + self.text_norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6) + self.text_mlp = MLP( + self.feat_dim, + self.mlp_hidden_dim, + act_type=mlp_act_type, + bias=True, + ) + + def forward( + self, + motion_feat: Tensor, + text_feat: Tensor, + adapter: Tensor, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + ( + motion_shift_msa, + motion_scale_msa, + motion_gate_msa, + motion_shift_mlp, + motion_scale_mlp, + motion_gate_mlp, + ) = self.motion_mod(adapter).chunk(6, dim=-1) + ( + text_shift_msa, + text_scale_msa, + text_gate_msa, + text_shift_mlp, + text_scale_mlp, + text_gate_mlp, + ) = self.text_mod( + adapter + ).chunk(6, dim=-1) + + motion_modulated = self.motion_norm1(motion_feat) + motion_modulated = modulate(motion_modulated, shift=motion_shift_msa, scale=motion_scale_msa) + motion_qkv = self.motion_qkv(motion_modulated) + + motion_q, motion_k, motion_v = rearrange(motion_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + motion_q = self.motion_q_norm(motion_q).to(motion_v) + motion_k = self.motion_k_norm(motion_k).to(motion_v) + + if self.apply_rope_to_single_branch: + # NOTE: we don't apply RoPE to text_branch_two here + motion_q, motion_k = self.rotary_emb.apply_rotary_emb(motion_q, motion_k) + + text_modulated = self.text_norm1(text_feat) + text_modulated = modulate(text_modulated, shift=text_shift_msa, scale=text_scale_msa) + text_qkv = self.text_qkv(text_modulated) + + text_q, text_k, text_v = rearrange( + text_qkv, + "B L (K H D) -> K B L H D", + K=3, + H=self.num_heads, + ) + text_q = self.text_q_norm(text_q).to(text_v) + text_k = self.text_k_norm(text_k).to(text_v) + + q = torch.cat((motion_q, text_q), dim=1) + k = torch.cat((motion_k, text_k), dim=1) + v = torch.cat((motion_v, text_v), dim=1) + + if not self.apply_rope_to_single_branch: + q, k = self.rotary_emb.apply_rotary_emb(q, k) + + bsz, total_len, _, _ = q.shape + motion_len = motion_feat.shape[1] + text_len = text_feat.shape[1] + dropout_p = 0.0 if not self.training else self.dropout + + attn_output = attention( + q, + k, + v, + mode="torch", + drop_rate=dropout_p, + attn_mask=attn_mask, + causal=False, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + batch_size=bsz, + training=self.training, + ) + + motion_attn_output, text_attn_output = ( + attn_output[:, :motion_len, ...], + attn_output[:, motion_len:, ...], + ) + + motion_feat = motion_feat + apply_gate(self.motion_out_proj(motion_attn_output), gate=motion_gate_msa) + motion_feat = motion_feat + apply_gate( + self.motion_mlp( + modulate( + self.motion_norm2(motion_feat), + shift=motion_shift_mlp, + scale=motion_scale_mlp, + ) + ), + gate=motion_gate_mlp, + ) + + text_feat = text_feat + apply_gate(self.text_out_proj(text_attn_output), gate=text_gate_msa) + text_feat = text_feat + apply_gate( + self.text_mlp( + modulate( + self.text_norm2(text_feat), + shift=text_shift_mlp, + scale=text_scale_mlp, + ) + ), + gate=text_gate_mlp, + ) + + return motion_feat, text_feat + + +class MMSingleStreamBlock(MMBaseBlock): + def __init__( + self, + feat_dim: int, + num_heads: int, + mlp_ratio: float, + dropout: float, + mlp_act_type: str, + qk_norm_type: Optional[str] = None, + qkv_bias: bool = False, + positional_encoding_cfg: dict = { + "max_seq_len": 5000, + "use_real": True, + }, + apply_rope_to_single_branch: bool = True, + ): + super().__init__(feat_dim, num_heads, mlp_ratio, dropout, positional_encoding_cfg, apply_rope_to_single_branch) + + self.modulation = ModulateDiT(self.feat_dim, factor=3, act_type="silu") + self.norm = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=False, eps=1e-6) + + # qkv and mlp_in + qkv_factor = 3 + self.linear1 = nn.Linear(self.feat_dim, self.feat_dim * qkv_factor + self.mlp_hidden_dim, bias=qkv_bias) + # proj and mlp_out + self.linear2 = nn.Linear(self.feat_dim + self.mlp_hidden_dim, self.feat_dim, bias=qkv_bias) + + self.q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6) + self.k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6) + + self.mlp_act = get_activation_layer(mlp_act_type)() + + def forward( + self, + x: Tensor, + split_len: int, + adapter: Tensor, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + ( + shift_msa, + scale_msa, + gate_msa, + ) = self.modulation( + adapter + ).chunk(3, dim=-1) + x_modulated = modulate(self.norm(x), shift_msa, scale_msa) + + qkv, mlp_hidden = torch.split(self.linear1(x_modulated), [3 * self.feat_dim, self.mlp_hidden_dim], dim=-1) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + q1, q2 = q[:, :split_len, ...], q[:, split_len:, ...] + k1, k2 = k[:, :split_len, ...], k[:, split_len:, ...] + # apply rotary position embedding + if self.apply_rope_to_single_branch: + q1, k1 = self.rotary_emb.apply_rotary_emb(q1, k1) + q = torch.cat((q1, q2), dim=1) + k = torch.cat((k1, k2), dim=1) + if not self.apply_rope_to_single_branch: + q, k = self.rotary_emb.apply_rotary_emb(q, k) + + bsz, total_len = x_modulated.shape[:2] + dropout_p = 0.0 if not self.training else self.dropout + + attn_output = attention( + q, + k, + v, + mode="torch", + drop_rate=dropout_p, + attn_mask=attn_mask, + causal=False, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + batch_size=bsz, + training=self.training, + ) + output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp_hidden)), 2)) + + return x + apply_gate(output, gate=gate_msa) + + +class HunyuanMotionMMDiT(nn.Module): + def __init__( + self, + input_dim: int, + feat_dim: int, + output_dim: Optional[int] = None, + ctxt_input_dim: int = 4096, + vtxt_input_dim: int = 256, + text_refiner_module: str = "hymotion/network/token_refiner.SingleTokenRefiner", + text_refiner_cfg: dict = { + "num_layers": 2, + }, + num_layers: int = 12, + num_heads: int = 16, + mlp_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + norm_type: str = "layer", + qk_norm_type: str = "rms", + qkv_bias: bool = True, + dropout: float = 0.0, + final_layer_module: str = "hymotion/network/encoders.FinalLayer", + final_layer_cfg: dict = { + "act_type": "silu", + }, + mask_mode: Optional[str] = None, + apply_rope_to_single_branch: bool = True, + insert_start_token: bool = False, + with_long_skip_connection: bool = False, + time_factor: float = 1.0, + narrowband_length: float = 2.0, + **kwargs, + ): + super().__init__() + self.motion_input_dim = input_dim + self.ctxt_input_dim = ctxt_input_dim + self.vtxt_input_dim = vtxt_input_dim + self.feat_dim = feat_dim + self.output_dim = output_dim or input_dim + self.mask_mode = mask_mode + self.insert_start_token = insert_start_token + self.time_factor = time_factor + self.narrowband_length = narrowband_length * 30.0 + if self.insert_start_token: + self.start_token = nn.Parameter(torch.randn(1, feat_dim)) + self.with_long_skip_connection = with_long_skip_connection + if self.with_long_skip_connection: + from .encoders import FinalLayer + + self.long_skip_net = FinalLayer(feat_dim=feat_dim, out_dim=feat_dim, act_type="silu") + + self.input_encoder = nn.Linear(in_features=input_dim, out_features=feat_dim) + self.ctxt_encoder = nn.Linear(in_features=ctxt_input_dim, out_features=feat_dim) + self.vtxt_encoder = MLPEncoder(in_dim=vtxt_input_dim, feat_dim=feat_dim, num_layers=2, act_type="silu") + self.timestep_encoder = TimestepEmbeddingEncoder( + embedding_dim=feat_dim, + feat_dim=feat_dim, + time_factor=time_factor, + ) + + if text_refiner_module != "" and text_refiner_module is not None: + text_refiner_cfg.update(input_dim=feat_dim, feat_dim=feat_dim, num_heads=num_heads) + self._text_refiner_cfg = text_refiner_cfg.copy() + self.text_refiner = load_object(text_refiner_module, text_refiner_cfg) + + self.num_layers = num_layers + assert num_layers % 3 == 0, f"num_layers must be divisible by 3, but got {num_layers}" + self.mm_double_blocks_layers = int(num_layers // 3) + self.mm_single_blocks_layers = int(num_layers - num_layers // 3) + + self.double_blocks = nn.ModuleList( + [ + MMDoubleStreamBlock( + feat_dim=feat_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + dropout=dropout, + mlp_act_type=mlp_act_type, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + apply_rope_to_single_branch=apply_rope_to_single_branch, + ) + for _ in range(self.mm_double_blocks_layers) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + MMSingleStreamBlock( + feat_dim=feat_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + dropout=dropout, + mlp_act_type=mlp_act_type, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + apply_rope_to_single_branch=apply_rope_to_single_branch, + ) + for _ in range(self.mm_single_blocks_layers) + ] + ) + + final_layer_cfg.update(feat_dim=feat_dim, out_dim=self.output_dim) + self._final_layer_cfg = final_layer_cfg.copy() + self.final_layer = load_object(final_layer_module, final_layer_cfg) + + def forward( + self, + x: Tensor, + ctxt_input: Tensor, + vtxt_input: Tensor, + timesteps: Tensor, + x_mask_temporal: Tensor, + ctxt_mask_temporal: Tensor, + **kwargs, + ) -> Tensor: + device = get_module_device(self) + + motion_feat = self.input_encoder(x) + if self.with_long_skip_connection: + origin_feat = motion_feat + if self.insert_start_token: + # (B, 1, D) + (B, L, D) -> (B, L+1, D) + start_token = self.start_token[None].repeat(motion_feat.shape[0], 1, 1) + motion_feat = torch.cat((start_token, motion_feat), dim=1) + x_mask_temporal = torch.cat( + [ + torch.ones_like(x_mask_temporal[:, :1], dtype=torch.bool), + x_mask_temporal, + ], + dim=1, + ) + + timestep_feat = self.timestep_encoder(timesteps) + vtxt_feat = self.vtxt_encoder(vtxt_input.float()) + adapter = timestep_feat + vtxt_feat + + motion_key_padding_mask = self._canonical_mask(x_mask_temporal).to(device) + ctxt_key_padding_mask = self._canonical_mask(ctxt_mask_temporal).to(device) + seq_key_padding_mask = torch.cat((motion_key_padding_mask, ctxt_key_padding_mask), dim=1) + if self.mask_mode is None: + seq_mask = None + elif self.mask_mode == "causal": + motion_len = motion_feat.shape[1] + seq_mask = torch.triu( + torch.full((motion_len, motion_len), float("-inf"), device=device), + diagonal=1, + ) + elif self.mask_mode == "narrowband": + window = int(round(self.narrowband_length)) + motion_len = motion_feat.shape[1] + idx = torch.arange(motion_len, device=device) + dist = (idx[None, :] - idx[:, None]).abs() + band = dist <= window + seq_mask = torch.full((motion_len, motion_len), float("-inf"), device=device) + seq_mask = seq_mask.masked_fill(band, 0.0) + else: + raise ValueError(f"Unsupported mask mode: {self.mask_mode}") + + ctxt_feat = self.ctxt_encoder(ctxt_input.float()) + if hasattr(self, "text_refiner"): + ctxt_feat = self.text_refiner(x=ctxt_feat, t=timesteps, mask=(ctxt_key_padding_mask == 0).to(device)) + + # precompute shared attention masks (broadcastable over heads) + bsz = x.shape[0] + motion_len = motion_feat.shape[1] + text_len = ctxt_feat.shape[1] + total_len = motion_len + text_len + mask_dtype = motion_feat.dtype + attn_mask_double = self._build_dmm_attn_mask_shared( + bsz=bsz, + motion_len=motion_len, + text_len=text_len, + dtype=mask_dtype, + key_padding_mask=seq_key_padding_mask, + attn_mask=seq_mask, + device=device, + ) + for i_layer, mod in enumerate(self.double_blocks): + motion_feat, ctxt_feat = mod( + motion_feat=motion_feat, + text_feat=ctxt_feat, + adapter=adapter, + attn_mask=attn_mask_double, + ) + + # precompute shared attention masks for single stream blocks too + split_len = motion_feat.shape[1] + x = torch.cat((motion_feat, ctxt_feat), 1) + attn_mask_single = self._build_smm_attn_mask_shared( + bsz=bsz, + split_len=split_len, + total_len=total_len, + dtype=mask_dtype, + key_padding_mask=seq_key_padding_mask, + attn_mask=seq_mask, + device=device, + ) + for i_layer, mod in enumerate(self.single_blocks): + x = mod( + x=x, + split_len=split_len, + adapter=adapter, + attn_mask=attn_mask_single, + ) + + x = x[:, :split_len, ...] + if self.insert_start_token: + x = x[:, 1:, ...] + + if self.with_long_skip_connection: + # long skip only consider timestep_feat + x = self.long_skip_net(origin_feat, timestep_feat) + x + + predicted_res = self.final_layer(x, adapter) + return predicted_res + + @staticmethod + def _canonical_mask(input_mask: Tensor) -> Tensor: + if input_mask.ndim == 1: + input_mask = input_mask.unsqueeze(1) + key_padding_mask = torch.where( + input_mask, + torch.zeros_like(input_mask, dtype=torch.float), + torch.full_like(input_mask, float("-inf"), dtype=torch.float), + ) + return key_padding_mask + + def _build_dmm_attn_mask_shared( + self, + bsz: int, + motion_len: int, + text_len: int, + dtype: torch.dtype, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + device: torch.device, + ) -> Tensor: + """ + NOTE: + motion_k text_k + motion_q [M→M] [M→T] + text_q [T→M] [T→T] + only [M→M] contains given mask + """ + total_len = motion_len + text_len + base = torch.zeros((bsz, 1, total_len, total_len), dtype=dtype, device=device) + if attn_mask is not None: + if attn_mask.dim() != 2 or attn_mask.shape != (motion_len, motion_len): + raise RuntimeError( + f"attn_mask should be 2D with shape {(motion_len, motion_len)}, got {attn_mask.shape}" + ) + base[:, :, :motion_len, :motion_len] += attn_mask.view(1, 1, motion_len, motion_len) + if key_padding_mask is not None: + mask_total_len = key_padding_mask.shape[1] + if mask_total_len == motion_len: + pad = torch.zeros((bsz, text_len), dtype=key_padding_mask.dtype, device=device) + key_padding_mask = torch.cat((key_padding_mask, pad), dim=-1) + base = base + key_padding_mask.view(bsz, 1, 1, total_len) + # disable T→M + base[:, :, motion_len:, :motion_len] = float("-inf") + return base + + def _build_smm_attn_mask_shared( + self, + bsz: int, + split_len: int, + total_len: int, + dtype: torch.dtype, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + device: torch.device, + ) -> Tensor: + """ + NOTE: + motion_k text_k + motion_q [M→M] [M→T] + text_q [T→M] [T→T] + only [M→M] contains given mask + """ + base = torch.zeros((bsz, 1, total_len, total_len), dtype=dtype, device=device) + if attn_mask is not None: + if attn_mask.dim() != 2 or attn_mask.shape != (split_len, split_len): + raise RuntimeError(f"attn_mask should be 2D with shape {(split_len, split_len)}, got {attn_mask.shape}") + base[:, :, :split_len, :split_len] += attn_mask.view(1, 1, split_len, split_len) + if key_padding_mask is not None: + mask_total_len = key_padding_mask.shape[1] + if mask_total_len == split_len: + pad = torch.zeros( + (bsz, total_len - split_len), + dtype=key_padding_mask.dtype, + device=device, + ) + key_padding_mask = torch.cat((key_padding_mask, pad), dim=-1) + base = base + key_padding_mask.view(bsz, 1, 1, total_len) + # disable T→M + base[:, :, split_len:, :split_len] = float("-inf") + return base + + +if __name__ == "__main__": + # python -m hymotion.network.hymotion_mmdit + + from configs._base_.model_network_base import MOTION_MODEL_CONFIG # pyright: ignore + + network_module_cfg = MOTION_MODEL_CONFIG["1.04B_narrowband"]["network_module_args"] + network_module_cfg = dict(network_module_cfg) # convert to normal dict + + bsz, seq_len, text_seq_len, input_dim = 1, 360, 128, 201 + network_module_cfg["input_dim"] = input_dim + MMDiT = HunyuanMotionMMDiT(**network_module_cfg) + + x = torch.randn(bsz, seq_len, input_dim) + ctxt_condition = torch.randn(bsz, text_seq_len, 4096) + vtxt_condition = torch.randn(bsz, 1, 768) + timesteps = torch.randint(0, 1000, (bsz,)) + length = torch.arange(seq_len).unsqueeze(0).repeat(bsz, 1) + ctxt_length = torch.arange(text_seq_len).unsqueeze(0).repeat(bsz, 1) + x_mask_temporal = length < 100 + ctxt_mask_temporal = ctxt_length < 50 + x = MMDiT( + x=x, + ctxt_input=ctxt_condition, + vtxt_input=vtxt_condition, + timesteps=timesteps, + x_mask_temporal=x_mask_temporal, + ctxt_mask_temporal=ctxt_mask_temporal, + ) + assert x.shape == ( + bsz, + seq_len, + input_dim, + ), f"unexpected output shape: {x.shape}, which should be ({bsz}, {seq_len}, {input_dim})" + print(x.shape) diff --git a/hymotion/network/modulate_layers.py b/hymotion/network/modulate_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..092ce07c4773e4d5d5582d468e031eaf5037ab65 --- /dev/null +++ b/hymotion/network/modulate_layers.py @@ -0,0 +1,49 @@ +from typing import Optional + +import torch.nn as nn +from torch import Tensor + +from .bricks import get_activation_layer + + +class ModulateDiT(nn.Module): + def __init__(self, feat_dim: int, factor: int, act_type: str = "silu"): + super().__init__() + self.act = get_activation_layer(act_type)() + self.linear = nn.Linear(feat_dim, factor * feat_dim, bias=True) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: Tensor) -> Tensor: + return self.linear(self.act(x)) + + +def modulate(x: Tensor, shift: Optional[Tensor] = None, scale: Optional[Tensor] = None) -> Tensor: + if shift is not None and scale is not None: + assert len(x.shape) == len(shift.shape) == len(scale.shape), ( + "x, shift, scale must have the same number of dimensions, " + f"but got x.shape: {x.shape}, " + f"shift.shape: {shift.shape} " + f"and scale.shape: {scale.shape}" + ) + if shift is not None and scale is not None: + return x * (1 + scale) + shift + elif shift is not None: + return x + shift + elif scale is not None: + return x * (1 + scale) + else: + return x + + +def apply_gate(x: Tensor, gate: Optional[Tensor] = None, tanh: bool = False) -> Tensor: + if gate is not None: + assert len(x.shape) == len( + gate.shape + ), f"x, gate must have the same number of dimensions, but got {x.shape} and {gate.shape}" + if gate is None: + return x + if tanh: + return x * gate.tanh() + else: + return x * gate diff --git a/hymotion/network/positional_encoding.py b/hymotion/network/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0e2acd225829192d06c1610b021c63f286bcf6 --- /dev/null +++ b/hymotion/network/positional_encoding.py @@ -0,0 +1,174 @@ +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + + +class RotaryEmbedding(nn.Module): + def __init__( + self, + num_feats: int, + max_seq_len: Union[Tensor, int], + temperature: int = 10000, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + ) -> None: + super(RotaryEmbedding, self).__init__() + assert num_feats % 2 == 0, "num_feats (head_dim) must be even for RoPE." + self.num_feats = num_feats + self.max_seq_len = max_seq_len + self.temperature = temperature + self.use_real = use_real + self.theta_rescale_factor = theta_rescale_factor + self.interpolation_factor = interpolation_factor + + if isinstance(max_seq_len, int): + max_seq_len = torch.arange(max_seq_len).float() + + if theta_rescale_factor != 1.0: + temperature *= theta_rescale_factor ** (self.num_feats / (self.num_feats - 2)) + dim_t = torch.arange(0, self.num_feats, 2, dtype=torch.float32) + freqs = 1.0 / (temperature ** (2 * torch.div(dim_t, 2, rounding_mode="trunc") / self.num_feats)) # [D/2] + freqs = torch.outer(max_seq_len.float() * interpolation_factor, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + self.freqs_cis = (freqs_cos, freqs_sin) + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # [S, D/2] + self.freqs_cis = freqs_cis + + def reshape_for_broadcast( + self, freqs_cis: Union[Tensor, Tuple[Tensor, Tensor]], x: Tensor + ) -> Union[Tuple[Tensor, Tensor], Tensor]: + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + assert ( + freqs_cis[0].shape[-1] == x.shape[-1] + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape} on the head_dim dimension" + assert freqs_cis[0].shape[0] >= x.shape[1], ( + f"freqs_cis shape {freqs_cis[0].shape} should be larger than or equal to " + f"x shape {x.shape} on the time dimension" + ) + shape = [] + for i, d in enumerate(x.shape): + if i == 1: + shape.append(-1) + elif i == ndim - 1: + shape.append(d) + else: + shape.append(1) + return ( + freqs_cis[0].view(*shape)[:, : x.shape[1], ...], + freqs_cis[1].view(*shape)[:, : x.shape[1], ...], + ) + else: + # freqs_cis: values in complex space + assert ( + freqs_cis.shape[-1] == x.shape[-1] + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape} on the head_dim dimension" + assert freqs_cis.shape[0] >= x.shape[1], ( + f"freqs_cis shape {freqs_cis.shape} should be larger than or equal to " + f"x shape {x.shape} on the time dimension" + ) + shape = [] + for i, d in enumerate(x.shape): + if i == 1: + shape.append(-1) + elif i == ndim - 1: + shape.append(d) + else: + shape.append(1) + return freqs_cis.view(*shape)[:, : x.shape[1], ...] + + def rotate_half(self, x: Tensor) -> Tensor: + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + def apply_rotary_emb(self, xq: Tensor, xk: Tensor) -> Tuple[Tensor, Tensor]: + xk_out = None + if isinstance(self.freqs_cis, tuple): + cos, sin = self.reshape_for_broadcast(self.freqs_cis, xq) # [B, L, H, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk) + else: + # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = self.reshape_for_broadcast(self.freqs_cis, xq_) + # Handle device transfer based on return type + if isinstance(freqs_cis, tuple): + freqs_cis = (freqs_cis[0].to(xq.device), freqs_cis[1].to(xq.device)) + else: + freqs_cis = freqs_cis.to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) + # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"max_seq_len={self.max_seq_len}, " + repr_str += f"temperature={self.temperature}, " + repr_str += f"use_real={self.use_real}, " + repr_str += f"theta_rescale_factor={self.theta_rescale_factor}, " + repr_str += f"interpolation_factor={self.interpolation_factor})" + return repr_str + + +class PositionalEncoding(nn.Module): + def __init__(self, num_feats: int, dropout: float = 0.1, max_len: int = 5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, num_feats) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, num_feats, 2).float() * (-np.log(10000.0) / num_feats)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) # shape of [1, L, D] + self.register_buffer("pe", pe) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.pe[:, : x.shape[1], :] # shape of [B, L, D] + return self.dropout(x) + + +if __name__ == "__main__": + # python -m hymotion.network.positional_encoding + num_feats = 32 + rope = RotaryEmbedding(num_feats=num_feats, max_seq_len=5000, use_real=True) + x = torch.ones(1, 360, 1, num_feats) + text = torch.ones(1, 256, 1, num_feats) + q1, k1 = x.clone(), x.clone() + q2, k2 = text.clone(), text.clone() + print(x.shape) + # q1, k1 = rope.apply_rotary_emb(q1, k1) + # q2, k2 = rope.apply_rotary_emb(q2, k2) + q = torch.cat([q1, q2], dim=1) + k = torch.cat([k1, k2], dim=1) + q, k = rope.apply_rotary_emb(q, k) + q, k = q[0, :, 0, :], k[0, :, 0, :] + attn = (q[:, None] * k[None, :]).sum(dim=-1) + # softmax + # attn = torch.softmax(attn, dim=-1) + attn = attn.cpu().numpy() + + import matplotlib.pyplot as plt + + plt.imshow(attn, cmap="hot") + plt.colorbar() + plt.savefig("attn.png") + breakpoint() diff --git a/hymotion/network/text_encoders/model_constants.py b/hymotion/network/text_encoders/model_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..98bb280ed2e3db072ddc9b01c2c08fa21a8b13ef --- /dev/null +++ b/hymotion/network/text_encoders/model_constants.py @@ -0,0 +1,8 @@ +__all__ = [ + "PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION", +] + + +PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION = """ + Summarize human motion only from the user text for representation: action categories, key body-part movements, order/transitions, trajectory/direction, posture; include style/emotion/speed only if present. Explicitly capture laterality (left/right) when mentioned; do not guess. If multiple actions are described, indicate the count of distinct actions (e.g., actions=3) and their order. Do not invent missing info. Keep one concise paragraph. +""" diff --git a/hymotion/network/text_encoders/text_encoder.py b/hymotion/network/text_encoders/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4748aafb98925f9dd466b4698fd31d9be1568462 --- /dev/null +++ b/hymotion/network/text_encoders/text_encoder.py @@ -0,0 +1,293 @@ +import os +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + CLIPTextModel, + CLIPTokenizer, +) + +from ...utils.type_converter import get_module_device +from .model_constants import PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION + +USE_HF_MODELS = os.environ.get("USE_HF_MODELS", "0") == "1" + +if USE_HF_MODELS: + QWEN_PATH = "Qwen/Qwen3-8B" + CLIP_PATH = "openai/clip-vit-large-patch14" +else: + QWEN_PATH = "ckpts/Qwen3-8B" + CLIP_PATH = "ckpts/clip-vit-large-patch14" + +LLM_ENCODER_LAYOUT = { + "qwen3": { + "module_path": QWEN_PATH, + "template": [ + {"role": "system", "content": f"{PROMPT_TEMPLATE_ENCODE_HUMAN_MOTION}"}, + {"role": "user", "content": "{}"}, + ], + "crop_start": 0, + "tokenizer_class": AutoTokenizer, + "text_encoder_class": AutoModelForCausalLM, + }, +} + +SENTENCE_EMB_LAYOUT = { + "clipl": { + "module_path": CLIP_PATH, + "tokenizer_class": CLIPTokenizer, + "text_encoder_class": CLIPTextModel, + "pooling_mode": "pooler_output", + "max_length": 77, + }, +} + + +class HYTextModel(nn.Module): + def __init__( + self, + llm_type: Optional[str] = "qwen3", + max_length_llm: int = 512, + sentence_emb_type: Optional[str] = "clipl", + max_length_sentence_emb: int = 77, + enable_llm_padding: bool = True, + ) -> None: + super().__init__() + self.text_encoder_type = "hy_text_model" + + self.sentence_emb_type = sentence_emb_type + self.sentence_emb_text_encoder = None + self.sentence_emb_tokenizer = None + self.vtxt_dim = 0 + if sentence_emb_type is not None: + assert sentence_emb_type in SENTENCE_EMB_LAYOUT, f"Unsupported sentence embedding type: {sentence_emb_type}" + self.max_length_sentence_emb = max_length_sentence_emb or SENTENCE_EMB_LAYOUT[sentence_emb_type].get( + "max_length", 77 + ) + self._sentence_emb_pooling_mode = SENTENCE_EMB_LAYOUT[sentence_emb_type].get( + "pooling_mode", "pooler_output" + ) + tokenizer_kwargs = SENTENCE_EMB_LAYOUT[sentence_emb_type].get("tokenizer_kwargs", {}) + + self.sentence_emb_tokenizer = SENTENCE_EMB_LAYOUT[sentence_emb_type]["tokenizer_class"].from_pretrained( + SENTENCE_EMB_LAYOUT[sentence_emb_type]["module_path"], + max_length=self.max_length_sentence_emb, + **tokenizer_kwargs, + ) + self.sentence_emb_text_encoder = SENTENCE_EMB_LAYOUT[sentence_emb_type][ + "text_encoder_class" + ].from_pretrained(SENTENCE_EMB_LAYOUT[sentence_emb_type]["module_path"]) + self.sentence_emb_text_encoder = self.sentence_emb_text_encoder.eval().requires_grad_(False) + self.vtxt_dim = self.sentence_emb_text_encoder.config.hidden_size + + self.llm_type = llm_type + self.llm_text_encoder = None + self.llm_tokenizer = None + self.ctxt_dim = 0 + self.crop_start = 0 + self.max_length_llm = max_length_llm + if llm_type is not None: + assert llm_type in LLM_ENCODER_LAYOUT, f"Unsupported LLM type: {llm_type}" + self._orig_max_length_llm = max_length_llm + self.enable_llm_padding = enable_llm_padding + self.llm_tokenizer = LLM_ENCODER_LAYOUT[llm_type]["tokenizer_class"].from_pretrained( + LLM_ENCODER_LAYOUT[llm_type]["module_path"], + padding_side="right", + ) + self.llm_text_encoder = LLM_ENCODER_LAYOUT[llm_type]["text_encoder_class"].from_pretrained( + LLM_ENCODER_LAYOUT[llm_type]["module_path"], low_cpu_mem_usage=True + ) + self.llm_text_encoder = self.llm_text_encoder.eval().requires_grad_(False) + self.ctxt_dim = self.llm_text_encoder.config.hidden_size + + self.crop_start = self._compute_crop_start() + self.max_length_llm = self._orig_max_length_llm + self.crop_start + + @torch.no_grad() + def encode_llm(self, text: List[str]) -> Tuple[Tensor, Tensor]: + if self.llm_type is None or self.llm_text_encoder is None or self.llm_tokenizer is None: + raise ValueError("LLM model not initialized") + + device = get_module_device(self) + llm_text = [ + ( + self.llm_tokenizer.apply_chat_template( + self.apply_text_to_template(one_text, LLM_ENCODER_LAYOUT[self.llm_type]["template"]), + tokenize=False, + add_generation_prompt=False, + enable_thinking=False, + ) + if self.llm_type == "qwen3" + else self.apply_text_to_template(one_text, LLM_ENCODER_LAYOUT[self.llm_type]["template"]) + ) + for one_text in text + ] + padding_mode = "max_length" if self.enable_llm_padding else False + llm_batch_encoding = self.llm_tokenizer( + llm_text, + return_length=False, + return_overflowing_tokens=False, + truncation=True, + return_attention_mask=True, + max_length=self.max_length_llm, # = crop_start + _orig_max_length_llm + padding=padding_mode, + return_tensors="pt", + ) + llm_outputs = ( + self.llm_text_encoder( + input_ids=llm_batch_encoding["input_ids"].to(device), + attention_mask=llm_batch_encoding["attention_mask"].to(device), + output_hidden_states=True, + ) + if self.llm_type == "qwen3" + else self.llm_text_encoder( + input_ids=llm_batch_encoding["input_ids"].to(device), + attention_mask=llm_batch_encoding["attention_mask"].to(device), + ) + ) + if self.llm_type == "qwen3": + ctxt_raw = llm_outputs.hidden_states[-1] + else: + ctxt_raw = llm_outputs.last_hidden_state + + start = self.crop_start + end = start + self._orig_max_length_llm + ctxt_raw = ctxt_raw[:, start:end].contiguous() # [bs, _orig_max_length_llm, hidden] + ctxt_length = (llm_batch_encoding["attention_mask"].sum(dim=-1).to(device) - start).clamp( + min=0, max=self._orig_max_length_llm + ) + return ctxt_raw, ctxt_length + + @torch.no_grad() + def encode_sentence_emb(self, text: List[str]) -> Tensor: + if ( + self.sentence_emb_type is None + or self.sentence_emb_text_encoder is None + or self.sentence_emb_tokenizer is None + ): + raise ValueError("Sentence embedding model not initialized") + + device = get_module_device(self) + enc = self.sentence_emb_tokenizer( + text, + return_length=False, + return_overflowing_tokens=False, + truncation=True, + return_attention_mask=True, + max_length=self.max_length_sentence_emb, + padding=True, + return_tensors="pt", + ) + out = self.sentence_emb_text_encoder( + input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device) + ) + if self._sentence_emb_pooling_mode == "pooler_output": + # Pooler output pooling (clip-vit-large-patch14 等) + if hasattr(out, "pooler_output") and out.pooler_output is not None: + vtxt_raw = out.pooler_output.unsqueeze(1) + else: + vtxt_raw = self._encode_pooling(enc["attention_mask"].to(device), out.last_hidden_state) + elif self._sentence_emb_pooling_mode == "mean": + vtxt_raw = self._encode_pooling(enc["attention_mask"].to(device), out.last_hidden_state) + elif self._sentence_emb_pooling_mode == "last_token": + vtxt_raw = self._last_token_pool(out.last_hidden_state, enc["attention_mask"].to(device)) + else: + raise ValueError(f"Unknown pooling mode: {self._sentence_emb_pooling_mode}") + + return vtxt_raw + + def encode(self, text: List[str]) -> Tuple[Tensor, Tensor, Tensor]: + ctxt_raw, ctxt_length = self.encode_llm(text=text) + vtxt_raw = self.encode_sentence_emb(text=text) + return vtxt_raw, ctxt_raw, ctxt_length + + @staticmethod + def apply_text_to_template(text: str, template: Union[str, list]) -> Union[str, list]: + if isinstance(template, str): + return template.format(text) + elif isinstance(template, list): + return [ + {"role": "system", "content": f"{template[0]['content']}"}, + {"role": "user", "content": f"{text}"}, + ] + else: + raise TypeError(f"Unsupported template type: {type(template)}") + + def _compute_crop_start(self) -> int: + if self.llm_type is None or self.llm_text_encoder is None or self.llm_tokenizer is None: + raise ValueError("LLM model not initialized") + + def _find_subseq(a: str, b: str) -> int: + for i in range(0, len(a) - len(b) + 1): + if a[i : i + len(b)] == b: + return i + return -1 + + marker = "" + if self.llm_type == "qwen3": + msgs = self.apply_text_to_template(marker, LLM_ENCODER_LAYOUT[self.llm_type]["template"]) + s = self.llm_tokenizer.apply_chat_template( + msgs, tokenize=False, add_generation_prompt=False, enable_thinking=False + ) + else: + s = self.apply_text_to_template(marker, LLM_ENCODER_LAYOUT[self.llm_type]["template"]) + full_ids = self.llm_tokenizer(s, return_tensors="pt", add_special_tokens=True)["input_ids"][0].tolist() + marker_ids = self.llm_tokenizer(marker, return_tensors="pt", add_special_tokens=False)["input_ids"][0].tolist() + pos = _find_subseq(full_ids, marker_ids) + if pos >= 0: + return pos + else: + return max(0, len(full_ids) - 1) + + def _pad_or_truncate_tensor(self, tensor: Tensor, target_length: int, dim: int = 0) -> Tensor: + current_length = tensor.shape[dim] + if current_length > target_length: + return tensor.narrow(dim, 0, target_length) + elif current_length < target_length: + pad_shape = list(tensor.shape) + pad_shape[dim] = target_length - current_length + padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor.narrow(dim, -1, 1) + return torch.cat([tensor, padding], dim=dim) + return tensor + + def _encode_pooling(self, attention_mask: Tensor, token_embeddings: Tensor) -> Tensor: + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + sentence_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) + vtxt_raw = nn.functional.normalize(sentence_embeddings, p=2, dim=1).unsqueeze(1) # shape of [bs, 1, D] + return vtxt_raw + + def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + vtxt_raw = last_hidden_states[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + vtxt_raw = last_hidden_states[ + torch.arange(batch_size, device=last_hidden_states.device), + sequence_lengths, + ] + vtxt_raw = nn.functional.normalize(vtxt_raw, p=2, dim=-1).unsqueeze(1) # shape of [bs, 1, D] + return vtxt_raw + + +if __name__ == "__main__": + # python -m hymotion.network.text_encoders.text_encoder + text_encoder = HYTextModel(llm_type="qwen3", max_length_llm=5) + vtxt_raw, ctxt_raw, ctxt_length = text_encoder.encode(["Hello, world!"]) + print(vtxt_raw.shape, ctxt_raw.shape, ctxt_length) + + crop_start = text_encoder._compute_crop_start() + print(f"crop_start: {crop_start} when using {text_encoder.llm_type}") + + assert ( + vtxt_raw.shape[1:] == (1, text_encoder.vtxt_dim) + and ctxt_raw.shape[1:] == (text_encoder._orig_max_length_llm, text_encoder.ctxt_dim) + and torch.all((ctxt_length >= 0) & (ctxt_length <= text_encoder._orig_max_length_llm)) + ), f"Got unexpected output shape: {vtxt_raw.shape}, {ctxt_raw.shape}, {ctxt_length}" diff --git a/hymotion/network/token_refiner.py b/hymotion/network/token_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..34cd4b5fb36e260da285055164847ba544bbea07 --- /dev/null +++ b/hymotion/network/token_refiner.py @@ -0,0 +1,192 @@ +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange +from torch import Tensor + +from .attention import attention +from .bricks import get_norm_layer +from .encoders import MLP, MLPEncoder, TimestepEmbeddingEncoder +from .modulate_layers import ModulateDiT, apply_gate + + +class IndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + feat_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + mlp_act_type: str = "silu", + qk_norm_type: str = "layer", + qkv_bias: bool = True, + ) -> None: + super().__init__() + self.feat_dim = feat_dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.dropout = dropout + assert self.feat_dim % num_heads == 0, f"feat_dim {self.feat_dim} must be divisible by num_heads {num_heads}" + self.head_dim = feat_dim // num_heads + + self.mlp_hidden_dim = int(feat_dim * mlp_ratio) + + self.norm1 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=True, eps=1e-6) + self.self_attn_qkv = nn.Linear(feat_dim, feat_dim * 3, bias=qkv_bias) + self.self_attn_q_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6) + self.self_attn_k_norm = get_norm_layer(qk_norm_type)(self.head_dim, elementwise_affine=True, eps=1e-6) + self.self_attn_proj = nn.Linear(feat_dim, feat_dim, bias=qkv_bias) + + self.norm2 = get_norm_layer(norm_type="layer")(self.feat_dim, elementwise_affine=True, eps=1e-6) + + self.mlp = MLP( + in_dim=feat_dim, + feat_dim=self.mlp_hidden_dim, + act_type=mlp_act_type, + drop=dropout, + ) + + self.adaLN_modulation = ModulateDiT( + feat_dim=feat_dim, + factor=2, + act_type="silu", + ) + + def forward(self, x: Tensor, c: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1) + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + return x + + +class IndividualTokenRefiner(nn.Module): + def __init__( + self, + feat_dim: int, + num_heads: int, + num_layers: int, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + mlp_act_type: str = "silu", + qk_norm_type: str = "layer", + qkv_bias: bool = True, + ) -> None: + super().__init__() + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + feat_dim=feat_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + dropout=dropout, + mlp_act_type=mlp_act_type, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, x: Tensor, c: Tensor, mask: Optional[Tensor] = None) -> Tensor: + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] + seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + # assume the shape of self_attn_mask is [B, H, Q, K] and this is self-attention (Q==K==L) + L = self_attn_mask.size(-1) + diag = torch.eye(L, dtype=torch.bool, device=self_attn_mask.device).view(1, 1, L, L) # [1,1,L,L] + # mark which query row is "all False" (no visible key) + all_false = ~self_attn_mask.any(dim=-1, keepdim=False) # [B, H, Q] + # expand to [B, H, Q, K], only for these rows, back to diagonal visible + all_false = all_false.unsqueeze(-1).expand(-1, -1, -1, L) + self_attn_mask = torch.where(all_false, diag.expand_as(self_attn_mask), self_attn_mask) + + if self_attn_mask is not None: + self_attn_mask = torch.where( + self_attn_mask, + torch.zeros_like(self_attn_mask, dtype=torch.float), + torch.full_like(self_attn_mask, float("-inf"), dtype=torch.float), + ) + for block in self.blocks: + x = block(x, c, self_attn_mask) + return x + + +class SingleTokenRefiner(nn.Module): + def __init__( + self, + input_dim: int, + feat_dim: int, + num_heads: int, + num_layers: int, + mlp_ratio: float = 4.0, + dropout: float = 0.0, + mlp_act_type: str = "silu", + qk_norm_type: str = "layer", + qkv_bias: bool = True, + attn_mode: str = "torch", + **kwargs, + ) -> None: + super().__init__() + self.attn_mode = attn_mode + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." + + self.input_embedder = nn.Linear(input_dim, feat_dim, bias=True) + self.context_encoder = MLPEncoder( + in_dim=feat_dim, + feat_dim=feat_dim, + num_layers=2, + act_type=mlp_act_type, + ) + self.timestep_encoder = TimestepEmbeddingEncoder( + embedding_dim=feat_dim, + feat_dim=feat_dim, + act_type=mlp_act_type, + ) + + self.individual_token_refiner = IndividualTokenRefiner( + feat_dim=feat_dim, + num_heads=num_heads, + num_layers=num_layers, + mlp_ratio=mlp_ratio, + dropout=dropout, + mlp_act_type=mlp_act_type, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + ) + + def forward(self, x: Tensor, t: Tensor, mask: Optional[Tensor] = None) -> Tensor: + timestep_aware_representations = self.timestep_encoder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.float().unsqueeze(-1) + denom = mask_float.sum(dim=1).clamp_min(1e-6) + context_aware_representations = (x * mask_float).sum(dim=1) / denom + context_aware_representations = self.context_encoder(context_aware_representations).unsqueeze(1) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + + x = self.individual_token_refiner(x, c, mask) + + return x diff --git a/hymotion/pipeline/body_model.py b/hymotion/pipeline/body_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b6977eb8c5eebbc84d88250ce2259c47405e15dd --- /dev/null +++ b/hymotion/pipeline/body_model.py @@ -0,0 +1,412 @@ +import json +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + +from ..utils.geometry import ( + rot6d_to_rotation_matrix, + rotation_matrix_to_angle_axis, +) + +# fmt: off +LEFT_HAND_MEAN_AA = [ 0.1117, 0.0429, -0.4164, 0.1088, -0.0660, -0.7562, -0.0964, -0.0909, + -0.1885, -0.1181, 0.0509, -0.5296, -0.1437, 0.0552, -0.7049, -0.0192, + -0.0923, -0.3379, -0.4570, -0.1963, -0.6255, -0.2147, -0.0660, -0.5069, + -0.3697, -0.0603, -0.0795, -0.1419, -0.0859, -0.6355, -0.3033, -0.0579, + -0.6314, -0.1761, -0.1321, -0.3734, 0.8510, 0.2769, -0.0915, -0.4998, + 0.0266, 0.0529, 0.5356, 0.0460, -0.2774] +RIGHT_HAND_MEAN_AA = [ 0.1117, -0.0429, 0.4164, 0.1088, 0.0660, 0.7562, -0.0964, 0.0909, + 0.1885, -0.1181, -0.0509, 0.5296, -0.1437, -0.0552, 0.7049, -0.0192, + 0.0923, 0.3379, -0.4570, 0.1963, 0.6255, -0.2147, 0.0660, 0.5069, + -0.3697, 0.0603, 0.0795, -0.1419, 0.0859, 0.6355, -0.3033, 0.0579, + 0.6314, -0.1761, 0.1321, 0.3734, 0.8510, -0.2769, 0.0915, -0.4998, + -0.0266, -0.0529, 0.5356, -0.0460, 0.2774] +# fmt: on + + +def to_tensor(array, dtype=torch.float32, device=torch.device("cpu")): + if "torch.tensor" not in str(type(array)): + return torch.tensor(array, dtype=dtype).to(device) + else: + return array.to(device) + + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + """Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + """ + if len(rot_vecs.shape) > 2: + rot_vec_ori = rot_vecs + rot_vecs = rot_vecs.view(-1, 3) + else: + rot_vec_ori = None + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1).view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + if rot_vec_ori is not None: + rot_mat = rot_mat.reshape(*rot_vec_ori.shape[:-1], 3, 3) + return rot_mat + + +def load_model_data(model_path): + """ + Load wooden model data from binary files. + + Args: + model_path: path to the directory containing .bin files + + Returns: + dict containing: + - v_template: (V, 3) vertex template + - j_template: (J, 3) joint template + - skin_weights: (V, 4) skin weights + - skin_indices: (V, 4) skin indices + - parents: (J,) parent indices (kintree) + - faces: (F, 3) face indices + - joint_names: list of joint names + """ + model_path = Path(model_path) + + # Load vertex template: (V*3,) -> (V, 3) + with open(model_path / "v_template.bin", "rb") as f: + v_template_flat = np.frombuffer(f.read(), dtype=np.float32) + num_verts = len(v_template_flat) // 3 + v_template = v_template_flat.reshape(num_verts, 3) + + # Load joint template: (J*3,) -> (J, 3) + with open(model_path / "j_template.bin", "rb") as f: + j_template_flat = np.frombuffer(f.read(), dtype=np.float32) + num_joints = len(j_template_flat) // 3 + j_template = j_template_flat.reshape(num_joints, 3) + + # Load skin weights: (V*4,) -> (V, 4), 4 bones per vertex + with open(model_path / "skinWeights.bin", "rb") as f: + skin_weights_flat = np.frombuffer(f.read(), dtype=np.float32) + skin_weights = skin_weights_flat.reshape(num_verts, 4) + + # Load skin indices: (V*4,) -> (V, 4), 4 bone indices per vertex + with open(model_path / "skinIndice.bin", "rb") as f: + skin_indices_flat = np.frombuffer(f.read(), dtype=np.uint16) + skin_indices = skin_indices_flat.reshape(num_verts, 4).astype(np.int64) + + # Load kintree (parent indices): (J,) + with open(model_path / "kintree.bin", "rb") as f: + parents = np.frombuffer(f.read(), dtype=np.int32) + + # Load faces + with open(model_path / "faces.bin", "rb") as f: + faces_flat = np.frombuffer(f.read(), dtype=np.uint16) + faces = faces_flat.reshape(-1, 3) + + # Load joint names + joint_names_path = model_path / "joint_names.json" + if joint_names_path.exists(): + with open(joint_names_path, "r") as f: + joint_names = json.load(f) + else: + joint_names = [f"Joint_{i}" for i in range(num_joints)] + + return { + "v_template": v_template, + "j_template": j_template, + "skin_weights": skin_weights, + "skin_indices": skin_indices, + "parents": parents, + "faces": faces, + "joint_names": joint_names, + "num_joints": num_joints, + "num_verts": num_verts, + } + + +def simple_lbs(v_template, rot_mats, joints, parents, skin_weights, skin_indices): + """ + Simple Linear Blend Skinning without shape blending. + + Args: + v_template: (V, 3) template vertices + rot_mats: (B, J, 3, 3) rotation matrices for each joint + joints: (J, 3) joint positions in rest pose + parents: (J,) parent indices for each joint + skin_weights: (V, 4) skin weights for 4 bones per vertex + skin_indices: (V, 4) bone indices for 4 bones per vertex + + Returns: + vertices: (B, V, 3) transformed vertices + posed_joints: (B, J, 3) transformed joint positions + """ + batch_size = rot_mats.shape[0] + num_joints = rot_mats.shape[1] + num_verts = v_template.shape[0] + device = rot_mats.device + dtype = rot_mats.dtype + + # Compute relative joint positions + rel_joints = joints.clone() + rel_joints[1:] = joints[1:] - joints[parents[1:]] + + # Build transformation chain: transforms_mat (B, J, 4, 4) + transforms_mat = torch.zeros(batch_size, num_joints, 4, 4, device=device, dtype=dtype) + transforms_mat[..., :3, :3] = rot_mats + transforms_mat[..., :3, 3] = rel_joints.unsqueeze(0).expand(batch_size, -1, -1) + transforms_mat[..., 3, 3] = 1.0 + + # Forward kinematics: accumulate transforms from root to each joint + transform_chain = [transforms_mat[:, 0]] + for i in range(1, num_joints): + parent_idx = parents[i].item() + curr_transform = torch.bmm(transform_chain[parent_idx], transforms_mat[:, i]) + transform_chain.append(curr_transform) + + transforms = torch.stack(transform_chain, dim=1) # (B, J, 4, 4) + + # Get posed joint positions + posed_joints = transforms[..., :3, 3].clone() # (B, J, 3) + + # Compute relative transforms (for skinning) + # We need to subtract the rest pose joint positions from the transform + rel_transforms = transforms.clone() + joints_homo = F.pad(joints, [0, 1], value=0) # (J, 4) + transformed_rest = torch.einsum("bjcd,jd->bjc", transforms[..., :3, :], joints_homo) + rel_transforms[..., :3, 3] = transforms[..., :3, 3] - transformed_rest[..., :3] + + # Apply skinning: gather transforms for each vertex's 4 bones + # skin_indices: (V, 4), skin_weights: (V, 4) + vertex_transforms = torch.zeros(batch_size, num_verts, 4, 4, 4, device=device, dtype=dtype) + for k in range(4): + bone_idx = skin_indices[:, k].long() # (V,) + vertex_transforms[:, :, k] = rel_transforms[:, bone_idx] # (B, V, 4, 4) + + # Weight the transforms + skin_weights_expanded = skin_weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # (1, V, 4, 1, 1) + skin_weights_expanded = skin_weights_expanded.expand(batch_size, -1, -1, 4, 4) # (B, V, 4, 4, 4) + + weighted_transforms = (vertex_transforms * skin_weights_expanded).sum(dim=2) # (B, V, 4, 4) + + # Apply to vertices + v_homo = F.pad(v_template, [0, 1], value=1.0) # (V, 4) + vertices = torch.einsum("bvcd,vd->bvc", weighted_transforms[..., :3, :], v_homo) # (B, V, 3) + + return vertices, posed_joints + + +class WoodenMesh(torch.nn.Module): + """ + Wooden character mesh model that loads from binary files. + Uses simple LBS without shape blending (fixed skeleton). + """ + + def __init__(self, model_path="scripts/gradio/static/assets/dump_wooden"): + torch.nn.Module.__init__(self) + + # Load model data from .bin files + model = load_model_data(model_path) + + # Register buffers like original SMPLMesh + v_template = to_tensor(model["v_template"]) + self.register_buffer("v_template", v_template) + + j_template = to_tensor(model["j_template"]) + self.register_buffer("j_template", j_template) + + skin_weights = to_tensor(model["skin_weights"]) + self.register_buffer("skin_weights", skin_weights) + + skin_indices = to_tensor(model["skin_indices"], dtype=torch.long) + self.register_buffer("skin_indices", skin_indices) + + parents = to_tensor(model["parents"], dtype=torch.long) + self.register_buffer("parents", parents) + + # Store non-buffer attributes + self.faces = model["faces"] + self.joint_names = model["joint_names"] + self.num_joints = model["num_joints"] + self.num_verts = model["num_verts"] + + print(f"[WoodenMesh] Loaded model: {self.num_verts} vertices, {self.num_joints} joints") + + def forward(self, params, fast_forward=False): + """ + Forward pass to compute deformed vertices. + + Args: + params: dict containing: + - 'poses': (B, J*3) axis-angle rotations, or + - 'rot6d': (B, J, 6) 6D rotation representations + - 'trans': (B, 3) optional translation + + Returns: + dict with 'vertices' and 'vertices_wotrans' + """ + if "poses" in params: + poses = params["poses"] + batch_size = poses.shape[0] + rot_mats = batch_rodrigues(poses.view(-1, 3)).view([batch_size, -1, 3, 3]) + elif "rot6d" in params: + rot6d = params["rot6d"] + batch_size = rot6d.shape[0] + rot_mats = rot6d_to_rotation_matrix(rot6d).view([batch_size, -1, 3, 3]) + else: + raise ValueError("poses or rot6d must be in params") + + if rot_mats.shape[1] == 22: + eye = torch.eye(3, device=rot_mats.device, dtype=rot_mats.dtype)[None, None, :, :].repeat( + batch_size, 30, 1, 1 + ) + rot_mats = torch.cat([rot_mats, eye], dim=1) # (B, 22 + 30, 3, 3) + + # Simple LBS (no shape blending, fixed skeleton) + vertices, posed_joints = simple_lbs( + self.v_template, + rot_mats, + self.j_template, + self.parents, + self.skin_weights, + self.skin_indices, + ) + + # Vertices without translation (for pose-level supervision) + vertices_wotrans = vertices + + if "trans" in params: + trans = params["trans"] + vertices = vertices + trans[:, None, :] + + return { + "vertices": vertices, + "vertices_wotrans": vertices_wotrans, + "keypoints3d": posed_joints, + } + + def forward_batch(self, params): + assert "rot6d" in params and "trans" in params + rot6d = params["rot6d"] + trans = params["trans"] + bs, num_frames = rot6d.shape[:2] + rot6d_flat = rot6d.reshape(bs * num_frames, rot6d.shape[2], rot6d.shape[3]) + trans_flat = trans.reshape(bs * num_frames, trans.shape[2]) + result = self.forward( + { + "rot6d": rot6d_flat, + "trans": trans_flat, + } + ) + out = {} + for key in result: + out[key] = result[key].reshape(bs, num_frames, *result[key].shape[1:]) + return out + + +def construct_smpl_data_dict( + rot6d: Tensor, + transl: Tensor, + betas: Optional[Tensor] = None, + gender: str = "neutral", + use_default_hand_mean_pose: bool = False, +) -> dict: + rotation_matrix = rot6d_to_rotation_matrix(rot6d) + angle_axis = rotation_matrix_to_angle_axis(rotation_matrix) + left_hand_mean_pose = ( + torch.tensor( + LEFT_HAND_MEAN_AA, + device=angle_axis.device, + dtype=angle_axis.dtype, + ) + .unsqueeze(0) + .repeat(angle_axis.shape[0], 1) + .reshape(angle_axis.shape[0], -1, 3) + ) + right_hand_mean_pose = ( + torch.tensor( + RIGHT_HAND_MEAN_AA, + device=angle_axis.device, + dtype=angle_axis.dtype, + ) + .unsqueeze(0) + .repeat(angle_axis.shape[0], 1) + .reshape(angle_axis.shape[0], -1, 3) + ) + if angle_axis.shape[1] == 22: + angle_axis = torch.cat( + [ + angle_axis, + left_hand_mean_pose, + right_hand_mean_pose, + ], + dim=1, + ) + elif angle_axis.shape[1] == 52: + if use_default_hand_mean_pose: + angle_axis = torch.cat( + [ + angle_axis[:, :22], + left_hand_mean_pose, + right_hand_mean_pose, + ], + dim=1, + ) + else: + angle_axis = angle_axis + + assert angle_axis.shape[1] == 52, f"angle_axis should be 52, but got {angle_axis.shape[1]}" + dump = { + "betas": betas.cpu().numpy() if betas is not None else np.zeros((1, 16)), + "gender": gender, + "poses": angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1), + "trans": transl.cpu().numpy(), + "mocap_framerate": 30, + "num_frames": angle_axis.shape[0], + "Rh": angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1)[:, :3], + } + return dump + + +if __name__ == "__main__": + # python -m hymotion.pipeline.body_model + model_path = "scripts/gradio/static/assets/dump_wooden" + model = WoodenMesh(model_path) + params = { + "rot6d": torch.randn(1, 52, 6), + "trans": torch.randn(1, 3), + } + result = model(params) + print(result.keys()) + print(result["vertices"].shape) + print(result["vertices_wotrans"].shape) + print(result["keypoints3d"].shape) + params_batch = { + "rot6d": torch.randn(3, 100, 22, 6), + "trans": torch.randn(3, 100, 3), + } + result_batch = model.forward_batch(params_batch) + print(result_batch.keys()) + print(result_batch["vertices"].shape) + print(result_batch["vertices_wotrans"].shape) + print(result_batch["keypoints3d"].shape) diff --git a/hymotion/pipeline/motion_diffusion.py b/hymotion/pipeline/motion_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c62f5fa343694fdbe9c33f4f496f9fa89465e88b --- /dev/null +++ b/hymotion/pipeline/motion_diffusion.py @@ -0,0 +1,673 @@ +import os +import os.path as osp +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from scipy.signal import savgol_filter +from torch import Tensor +from torchdiffeq import odeint + +from ..utils.geometry import ( + matrix_to_quaternion, + quaternion_fix_continuity, + quaternion_to_matrix, + rot6d_to_rotation_matrix, + rotation_matrix_to_rot6d, +) +from ..utils.loaders import load_object +from ..utils.motion_process import smooth_rotation +from ..utils.type_converter import get_module_device +from .body_model import WoodenMesh + + +def length_to_mask(lengths: Tensor, max_len: int) -> Tensor: + """ + lengths: (B, 1) + max_len: int + Returns: (B, max_len) + """ + assert lengths.max() <= max_len, f"lengths.max()={lengths.max()} > max_len={max_len}" + if lengths.ndim == 1: + lengths = lengths.unsqueeze(1) + mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths + return mask + + +def start_end_frame_to_mask(start_frame: Tensor, end_frame: Tensor, max_len: int) -> Tensor: + # 生成一个 (B, max_len) 的mask,只有在[start_frame, end_frame]区间内为True,其余为False + assert (start_frame >= 0).all() and (end_frame >= 0).all(), f"start_frame={start_frame}, end_frame={end_frame}" + lengths = end_frame - start_frame + 1 + assert lengths.max() <= max_len, f"lengths.max()={lengths.max()} > max_len={max_len}" + if lengths.ndim == 1: + lengths = lengths.unsqueeze(1) + batch_size = start_frame.shape[0] + arange_ids = torch.arange(max_len, device=start_frame.device).unsqueeze(0).expand(batch_size, max_len) + mask = (arange_ids >= start_frame.unsqueeze(1)) & (arange_ids <= end_frame.unsqueeze(1)) + return mask + + +def randn_tensor( + shape, + generator=None, + device=None, + dtype=None, + layout=None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. + + When passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the + tensor is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + print( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents + + +class MotionGeneration(torch.nn.Module): + def __init__( + self, + network_module: str, + network_module_args: dict, + text_encoder_module: str, + text_encoder_cfg: dict, + mean_std_dir: str, + motion_type="auto", + **kwargs, + ): + super().__init__() + # build models and parameters + self._network_module_args = deepcopy(network_module_args) + self.motion_transformer = load_object(network_module, network_module_args) + self._text_encoder_module = text_encoder_module + self._text_encoder_cfg = deepcopy(text_encoder_cfg) + self.motion_type = motion_type + + self.null_vtxt_feat = torch.nn.Parameter( + torch.randn(1, 1, self._network_module_args.get("vtxt_input_dim", 768)) + ) + self.null_ctxt_input = torch.nn.Parameter( + torch.randn(1, 1, self._network_module_args.get("ctxt_input_dim", 4096)) + ) + self.special_game_vtxt_feat = torch.nn.Parameter( + torch.randn(1, 1, self._network_module_args.get("vtxt_input_dim", 768)) + ) + self.special_game_ctxt_feat = torch.nn.Parameter( + torch.randn(1, 1, self._network_module_args.get("ctxt_input_dim", 4096)) + ) + # build buffer + self.mean_std_dir = mean_std_dir + self._parse_buffer(self.motion_type) + + self.output_mesh_fps = kwargs.get("output_mesh_fps", 30) + self.train_frames = kwargs.get("train_frames", 360) + self.uncondition_mode = kwargs.get("uncondition_mode", False) + self.enable_ctxt_null_feat = kwargs.get("enable_ctxt_null_feat", False) + self.enable_special_game_feat = kwargs.get("enable_special_game_feat", False) + self.random_generator_on_gpu = kwargs.get("random_generator_on_gpu", True) + + def _parse_buffer(self, mode: str) -> None: + self.body_model = WoodenMesh() + self._find_motion_type(mode=mode) + self._load_mean_std() + + def _load_mean_std(self, mean_std_name: Optional[str] = None) -> None: + mean_std_name = self.mean_std_dir if mean_std_name is None else mean_std_name + if mean_std_name is not None and osp.isdir(mean_std_name): + mean = torch.from_numpy(np.load(osp.join(mean_std_name, "Mean.npy"))).float() + std = torch.from_numpy(np.load(osp.join(mean_std_name, "Std.npy"))).float() + self._assert_motion_dimension(mean.unsqueeze(0), std.unsqueeze(0)) + self.register_buffer("mean", mean) + self.register_buffer("std", std) + else: + print( + f"[{self.__class__.__name__}] No mean_std found, using blank mean_std, " + f"self.mean_std_dir={self.mean_std_dir}" + ) + self.register_buffer("mean", torch.zeros(1)) + self.register_buffer("std", torch.ones(1)) + + def _assert_motion_dimension(self, mean: Tensor, std: Tensor) -> None: + assert mean.shape == std.shape, f"mean.shape={mean.shape} != std.shape={std.shape}" + assert mean.ndim == 2, f"mean.ndim={mean.ndim} != 2" + assert mean.shape == (1, 201), f"mean.shape={mean.shape} != (1, 201)" + + def _find_motion_type(self, mode: str) -> None: + if mode == "auto": + self.motion_type = "o6dp" + else: + self.motion_type = mode + + def set_epoch(self, epoch) -> None: + self.current_epoch = epoch + + def load_in_demo( + self, + ckpt_name: str, + mean_std_name: Optional[str] = None, + build_text_encoder: bool = True, + allow_empty_ckpt: bool = False, + ) -> None: + if not allow_empty_ckpt: + if not os.path.exists(ckpt_name): + import warnings + warnings.warn(f"Checkpoint {ckpt_name} not found, skipping model loading") + else: + checkpoint = torch.load(ckpt_name, map_location="cpu", weights_only=False) + self.load_state_dict(checkpoint["model_state_dict"], strict=False) + if mean_std_name is not None: + assert os.path.exists(mean_std_name), f"{mean_std_name} not found" + if not os.path.isfile(mean_std_name): + mean_std_name = None + self._load_mean_std(mean_std_name) + self.motion_transformer.eval() + if build_text_encoder and not self.uncondition_mode: + self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg) + self.text_encoder.to(get_module_device(self)) + + @torch.no_grad() + def encode_text(self, text: Dict[str, List[str]]) -> Dict[str, Tensor]: + if not hasattr(self, "text_encoder"): + self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg) + self.text_encoder.to(get_module_device(self)) + text = text["text"] + vtxt_input, ctxt_input, ctxt_length = self.text_encoder.encode(text=text) + return { + "text_vec_raw": vtxt_input, + "text_ctxt_raw": ctxt_input, + "text_ctxt_raw_length": ctxt_length, + } + + def decode_motion_from_latent(self, latent: Tensor, should_apply_smooothing: bool = True) -> Dict[str, Tensor]: + std_zero = self.std < 1e-3 + std = torch.where(std_zero, torch.zeros_like(self.std), self.std) + latent_denorm = latent * std + self.mean + return self._decode_o6dp( + latent_denorm, + num_joints=22, + rel_trans=False, + should_apply_smooothing=should_apply_smooothing, + ) + + def _forward_smpl_batch( + self, + root_rot6d: Tensor, # (B, L, 1, 6) + body_rot6d: Tensor, # (B, L, 21, 6) + transl: Tensor, # (B, L, 3) + left_hand_pose: Optional[Tensor] = None, # (B, L, 15, 6) + right_hand_pose: Optional[Tensor] = None, # (B, L, 16, 6) + ) -> Tensor: + device = transl.device + bsz, L = transl.shape[:2] + k3d_all = [] + tmp_betas = torch.zeros(1, 16, device=device) + for bs in range(bsz): + out = self.body_model( + body_rot6d[bs], + tmp_betas, + root_rot6d[bs], + transl[bs], + left_hand_pose=(left_hand_pose[bs] if left_hand_pose is not None else None), + right_hand_pose=(right_hand_pose[bs] if right_hand_pose is not None else None), + ) + k3d_all.append(out.detach().cpu()) + return torch.stack(k3d_all, dim=0) # (B, L, J, 3) + + def _decode_o6dp( + self, + latent_denorm: torch.Tensor, + num_joints: int, + rel_trans: bool = False, + should_apply_smooothing: bool = True, + ) -> dict: + device = get_module_device(self) + B, L = latent_denorm.shape[:2] + nj = num_joints + body_n = nj - 1 + + if not rel_trans: + transl = latent_denorm[..., 0:3].clone() + else: + transl = torch.cumsum(latent_denorm[..., 0:3].clone(), dim=1) / self.output_mesh_fps + root_rot6d = latent_denorm[..., 3:9].reshape(B, L, 1, 6).clone() + + body6d_start = 9 + body6d_end = body6d_start + body_n * 6 + body_rot6d_full = latent_denorm[..., body6d_start:body6d_end].clone().reshape(B, L, body_n, 6) + + # 52 joints need to be split into hands + left_hand_pose = right_hand_pose = None + if nj == 52: + body_rot6d = body_rot6d_full[:, :, :21, :].clone() + left_hand_pose = body_rot6d_full[:, :, 21:36, :].clone() + right_hand_pose = body_rot6d_full[:, :, 36:51, :].clone() + else: + body_rot6d = body_rot6d_full + + if left_hand_pose is not None and right_hand_pose is not None: + body_full = torch.cat([body_rot6d, left_hand_pose, right_hand_pose], dim=2) + else: + body_full = body_rot6d + rot6d = torch.cat([root_rot6d, body_full], dim=2) # (B, L, nj, 6) + if should_apply_smooothing: + # only apply slerp smoothing to the first 22 joints (non-finger joints) + rot6d_body = rot6d[:, :, :22, :] # (B, L, 22, 6) + rot6d_fingers = rot6d[:, :, 22:, :] # (B, L, J-22, 6) + rot6d_body_smooth = self.smooth_with_slerp(rot6d_body, sigma=1.0) + rot6d_smooth = torch.cat([rot6d_body_smooth, rot6d_fingers], dim=2) + else: + rot6d_smooth = rot6d + root_rotmat_smooth = rot6d_to_rotation_matrix(rot6d_smooth[:, :, 0, :]) # (B, L, 3, 3) + + transl_fixed = transl.detach() + if should_apply_smooothing: + transl_smooth = self.smooth_with_savgol(transl_fixed.detach(), window_length=11, polyorder=5) + else: + transl_smooth = transl_fixed + + if self.body_model is not None: + print(f'{self.__class__.__name__} rot6d_smooth shape: {rot6d_smooth.shape}, transl_smooth shape: {transl_smooth.shape}') + with torch.no_grad(): + vertices_all = [] + k3d_all = [] + for bs in range(rot6d_smooth.shape[0]): + out = self.body_model.forward( + { + 'rot6d': rot6d_smooth[bs], + 'trans': transl_smooth[bs], + } + ) + vertices_all.append(out["vertices"]) + k3d_all.append(out['keypoints3d']) + vertices = torch.stack(vertices_all, dim=0) + k3d = torch.stack(k3d_all, dim=0) + print(f'{self.__class__.__name__} vertices shape: {vertices.shape}, k3d shape: {k3d.shape}') + # k3d = self._forward_smpl_batch( + # rot6d_smooth[:, :, 0:1, :].to(device), + # rot6d_smooth[:, :, 1:22, :].to(device), + # transl_smooth, + # left_hand_pose=(rot6d_smooth[:, :, 22:37, :].to(device) if left_hand_pose is not None else None), + # right_hand_pose=(rot6d_smooth[:, :, 37:52, :].to(device) if right_hand_pose is not None else None), + # ) + # align with the ground + min_y = vertices[..., 1].amin(dim=(1, 2), keepdim=True) # (B, 1, 1) + print(f'{self.__class__.__name__} min_y: {min_y}') + k3d = k3d.clone() + k3d[..., 1] -= min_y # (B, L, J) - (B, 1, 1) + transl_smooth = transl_smooth.clone() + transl_smooth[..., 1] -= min_y.squeeze(-1).to(device) # (B, L) - (B, 1) + else: + k3d = torch.zeros(B, L, nj, 3, device=device) + + return dict( + latent_denorm=latent_denorm, # (B, L, 201) + keypoints3d=k3d, # (B, L, J, 3) + rot6d=rot6d_smooth, # (B, L, J, 6) + transl=transl_smooth, # (B, L, 3) + root_rotations_mat=root_rotmat_smooth, # (B, L, 3, 3) + ) + + @staticmethod + def smooth_with_savgol(input: torch.Tensor, window_length: int = 9, polyorder: int = 5) -> torch.Tensor: + if len(input.shape) == 2: + is_batch = False + input = input.unsqueeze(0) + else: + is_batch = True + input_np = input.cpu().numpy() + input_smooth_np = np.empty_like(input_np, dtype=np.float32) + for b in range(input_np.shape[0]): + for j in range(input_np.shape[2]): + input_smooth_np[b, :, j] = savgol_filter(input_np[b, :, j], window_length, polyorder) + input_smooth = torch.from_numpy(input_smooth_np).to(input) + if not is_batch: + input_smooth = input_smooth.squeeze(0) + return input_smooth + + @staticmethod + def smooth_with_slerp(input: torch.Tensor, sigma: float = 1.0) -> torch.Tensor: + def fix_time_continuity(q: Tensor, time_dim: int = -3): + shape = q.shape + qv = q.moveaxis(time_dim, 0).contiguous().view(shape[time_dim], -1, 4) + qv = quaternion_fix_continuity(qv) + return qv.view(shape[time_dim], *shape[:time_dim], *shape[time_dim + 1 :]).moveaxis(0, time_dim) + + num_joints = input.shape[2] + RR = rot6d_to_rotation_matrix(input) + qq = matrix_to_quaternion(RR) + qq_np = fix_time_continuity(qq, time_dim=1).cpu().numpy() + qq_s_np = smooth_rotation( + qq_np, + sigma=sigma, + ) + input_smooth = rotation_matrix_to_rot6d(quaternion_to_matrix(torch.from_numpy(qq_s_np))) + return input_smooth.to(input.device) + + @staticmethod + def noise_from_seeds( + latent: Tensor, seeds: Union[int, List[int]], seed_start: int = 0, random_generator_on_gpu: bool = True + ) -> Tensor: + if isinstance(seeds, int): + seeds = list(range(seeds)) + noise_list = [] + B = latent.shape[0] + shape = (B, *latent.shape[1:]) + for seed in seeds: + if random_generator_on_gpu: + generator = torch.Generator(device=latent.device).manual_seed(seed + seed_start) + noise_sample = randn_tensor(shape, generator=generator, device=latent.device, dtype=latent.dtype) + else: + generator = torch.Generator().manual_seed(seed + seed_start) + noise_sample = randn_tensor(shape, generator=generator, dtype=latent.dtype).to(latent.device) + noise_list.append(noise_sample) + return torch.cat(noise_list, dim=0) + + def _maybe_inject_source_token( + self, + vtxt_input: Tensor, + ctxt_input: Tensor, + ctxt_mask_temporal: Tensor, + sources: Optional[List[str]], + trigger_sources: Optional[set] = None, + prob: float = 0.5, + ) -> Tuple[Tensor, Tensor, Tensor]: + if (sources is None or trigger_sources is None) or not self.enable_special_game_feat: + return vtxt_input, ctxt_input, ctxt_mask_temporal + + B, Lc, Dc = ctxt_input.shape + assert ( + isinstance(sources, (list, tuple)) and len(sources) == B + ), f"sources length should be equal to batch: {len(sources)} vs {B}" + + trig = set(s.lower() for s in trigger_sources) + src_mask = torch.tensor( + [str(s).lower() in trig for s in sources], dtype=torch.bool, device=ctxt_input.device + ) # (B,) + if not src_mask.any(): + return vtxt_input, ctxt_input, ctxt_mask_temporal + + rand_mask = ( + torch.rand(B, device=ctxt_input.device) < prob + if self.training + else torch.BoolTensor(B).fill_(True).to(ctxt_input.device) + ) + apply_mask = src_mask & rand_mask + if not apply_mask.any(): + return vtxt_input, ctxt_input, ctxt_mask_temporal + + # vtxt: only add mixture to the hit samples + vtxt_token = self.special_game_vtxt_feat.to(vtxt_input).expand(B, 1, -1) + vtxt_input = vtxt_input + vtxt_token * apply_mask.view(B, 1, 1).to(vtxt_input.dtype) + + # calculate the current effective length of each sample + if ctxt_mask_temporal.dtype == torch.bool: + cur_len = ctxt_mask_temporal.sum(dim=1).long() # (B,) + else: + cur_len = (ctxt_mask_temporal > 0).sum(dim=1).long() + + # for the "not full" hit samples, + # write the special token at the cur_len position, + # and set the mask to True + can_inplace = apply_mask & (cur_len < Lc) + b_inplace = torch.nonzero(can_inplace, as_tuple=False).squeeze(1) # (K,) + if b_inplace.numel() > 0: + pos = cur_len[b_inplace] # (K,) + token = self.special_game_ctxt_feat.squeeze(0).squeeze(0).to(ctxt_input) # (Dc,) + ctxt_input[b_inplace, pos, :] = token.unsqueeze(0).expand(b_inplace.numel(), Dc) + if ctxt_mask_temporal.dtype == torch.bool: + ctxt_mask_temporal[b_inplace, pos] = True + else: + ctxt_mask_temporal[b_inplace, pos] = 1 + + # if there are "full" hit samples, need to pad one: + # the full samples write the special token at the new position, + # other samples pad zero and mask=False + need_expand = (apply_mask & (cur_len >= Lc)).any() + if need_expand: + suffix = torch.zeros((B, 1, Dc), dtype=ctxt_input.dtype, device=ctxt_input.device) + full_hit = apply_mask & (cur_len >= Lc) + b_full = torch.nonzero(full_hit, as_tuple=False).squeeze(1) + if b_full.numel() > 0: + suffix[b_full, 0, :] = ( + self.special_game_ctxt_feat.expand(b_full.numel(), 1, -1).to(ctxt_input).squeeze(1) + ) + ctxt_input = torch.cat([ctxt_input, suffix], dim=1) + + if ctxt_mask_temporal.dtype == torch.bool: + suffix_mask = torch.zeros((B, 1), dtype=torch.bool, device=ctxt_input.device) + suffix_mask[b_full, 0] = True + else: + suffix_mask = torch.zeros((B, 1), dtype=ctxt_mask_temporal.dtype, device=ctxt_input.device) + suffix_mask[b_full, 0] = 1 + ctxt_mask_temporal = torch.cat([ctxt_mask_temporal, suffix_mask], dim=1) + + return vtxt_input, ctxt_input, ctxt_mask_temporal + + +class MotionFlowMatching(MotionGeneration): + def __init__( + self, + network_module: str, + network_module_args: dict, + text_encoder_module: str, + text_encoder_cfg: dict, + noise_scheduler_cfg: dict = {"method": "euler"}, + infer_noise_scheduler_cfg: dict = {"validation_steps": 50}, + mean_std_dir: Optional[str] = None, + losses_cfg: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + **kwargs, + ): + super().__init__( + network_module=network_module, + network_module_args=network_module_args, + text_encoder_module=text_encoder_module, + text_encoder_cfg=text_encoder_cfg, + losses_cfg=losses_cfg, + mean_std_dir=(mean_std_dir if mean_std_dir is not None else test_cfg.get("mean_std_dir", None)), + **kwargs, + ) + # build scheduler + self._noise_scheduler_cfg = deepcopy(noise_scheduler_cfg) + self._infer_noise_scheduler_cfg = deepcopy(infer_noise_scheduler_cfg) + # additional cfg + self.train_cfg = deepcopy(train_cfg) if train_cfg else dict() + self.test_cfg = deepcopy(test_cfg) if test_cfg else dict() + self._parse_test_cfg() + + def _parse_test_cfg(self) -> None: + self.validation_steps = self._infer_noise_scheduler_cfg["validation_steps"] + self.text_guidance_scale = self.test_cfg.get("text_guidance_scale", 1) + + @torch.no_grad() + def generate( + self, + text: Union[str, List[str]], + seed_input: List[int], + duration_slider: int, + cfg_scale: Optional[float] = None, + use_special_game_feat: bool = False, + hidden_state_dict=None, + length=None, + ) -> Dict[str, Any]: + device = get_module_device(self) + if length is None: + length = int(round(duration_slider * self.output_mesh_fps)) + assert ( + 0 < length < 5000 + ), f"input duration_slider must be in (0, {5000/self.output_mesh_fps}] due to rope, but got {duration_slider}" + if length > self.train_frames or length < min(self.train_frames, 20): + print(f">>> given length is too long or too short, got {length}, will be truncated") + length = min(length, self.train_frames) + length = max(length, min(self.train_frames, 20)) + + repeat = len(seed_input) + if isinstance(text, list): + assert len(text) == repeat, f"len(text) must equal len(seed_input), got {len(text)} vs {repeat}" + text_list = text + elif isinstance(text, str): + text_list = [text] * repeat + else: + raise TypeError(f"Unsupported text type: {type(text)}") + + if not self.uncondition_mode: + if hidden_state_dict is None: + hidden_state_dict = self.encode_text({"text": text_list}) + vtxt_input = hidden_state_dict["text_vec_raw"] + ctxt_input = hidden_state_dict["text_ctxt_raw"] + ctxt_length = hidden_state_dict["text_ctxt_raw_length"] + # check shape + if len(vtxt_input.shape) == 2 and len(ctxt_input.shape) == 2: + vtxt_input = vtxt_input[None].repeat(repeat, 1, 1) + ctxt_input = ctxt_input[None].repeat(repeat, 1, 1) + ctxt_length = ctxt_length.repeat(repeat) + ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1]) + sources = None if not use_special_game_feat else ["Game"] * repeat + vtxt_input, ctxt_input, ctxt_mask_temporal = self._maybe_inject_source_token( + vtxt_input, ctxt_input, ctxt_mask_temporal, sources, trigger_sources={"Taobao", "Game"} + ) + else: + vtxt_input = self.null_vtxt_feat.expand(repeat, 1, -1) + ctxt_input = self.null_ctxt_input.expand(repeat, 1, -1) + ctxt_length = torch.tensor([1]).expand(repeat) + ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1]).expand(repeat, -1) + assert len(vtxt_input.shape) == 3, f"vtxt_input.shape: {vtxt_input.shape}, should be (B, 1, D)" + assert len(ctxt_input.shape) == 3, f"ctxt_input.shape: {ctxt_input.shape}, should be (B, 1, D)" + assert len(ctxt_length.shape) == 1, f"ctxt_length.shape: {ctxt_length.shape}, should be (B,)" + + ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1]) + x_length = torch.LongTensor([length] * repeat).to(device) + x_mask_temporal = length_to_mask(x_length, self.train_frames) + + text_guidance_scale = cfg_scale if cfg_scale is not None else self.text_guidance_scale + do_classifier_free_guidance = text_guidance_scale > 1.0 and not self.uncondition_mode + if do_classifier_free_guidance is True: + silent_text_feat = self.null_vtxt_feat.expand(*vtxt_input.shape) + vtxt_input = torch.cat([silent_text_feat, vtxt_input], dim=0) + + if self.enable_ctxt_null_feat: + silent_ctxt_input = self.null_ctxt_input.expand(*ctxt_input.shape) + else: + silent_ctxt_input = ctxt_input + ctxt_input = torch.cat([silent_ctxt_input, ctxt_input], dim=0) + + ctxt_mask_temporal = torch.cat([ctxt_mask_temporal] * 2, dim=0) + x_mask_temporal = torch.cat([x_mask_temporal] * 2, dim=0) + + def fn(t: Tensor, x: Tensor) -> Tensor: + # predict flow + x_input = torch.cat([x] * 2, dim=0) if do_classifier_free_guidance else x + x_pred = self.motion_transformer( + x=x_input, + ctxt_input=ctxt_input, + vtxt_input=vtxt_input, + timesteps=t.expand(x_input.shape[0]), + x_mask_temporal=x_mask_temporal, + ctxt_mask_temporal=ctxt_mask_temporal, + ) + if do_classifier_free_guidance: + x_pred_basic, x_pred_text = x_pred.chunk(2, dim=0) + x_pred = x_pred_basic + text_guidance_scale * (x_pred_text - x_pred_basic) + return x_pred + + # duplicate test corner for inner time step oberservation + t = torch.linspace(0, 1, self.validation_steps + 1, device=device) + y0 = self.noise_from_seeds( + torch.zeros( + 1, + self.train_frames, + self._network_module_args["input_dim"], + device=device, + ), + seed_input, + random_generator_on_gpu=self.random_generator_on_gpu, + ) + with torch.no_grad(): + trajectory = odeint(fn, y0, t, **self._noise_scheduler_cfg) + sampled = trajectory[-1] + assert isinstance(sampled, Tensor), f"sampled must be a Tensor, but got {type(sampled)}" + sampled = sampled[:, :length, ...].clone() + + output_dict = self.decode_motion_from_latent(sampled, should_apply_smooothing=True) + + return { + **output_dict, + "text": text, + } + + +if __name__ == "__main__": + # python -m hymotion.pipeline.motion_diffusion + import time + + import torch + + device = "cuda:0" + bsz, input_dim = 64, 272 + seq_lens = [90, 180, 360] + ctxt_seq_lens = 64 + warmup = 5 + repeats = 100 + + network_module = "hymotion/network/hymotion_mmdit.HunyuanMotionMMDiT" + network_module_args = { + "input_dim": input_dim, + "feat_dim": 512, + "ctxt_input_dim": 4096, + "vtxt_input_dim": 768, + "num_layers": 12, + "num_heads": 4, + "mlp_ratio": 2.0, + "dropout": 0.0, + "mask_mode": "narrowband", + } + text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel" + text_encoder_cfg = {"llm_type": "qwen3", "max_length_llm": ctxt_seq_lens} + + # ================================ FM_MMDiT ================================ + FM_MMDiT = MotionFlowMatching( + network_module=network_module, + network_module_args=network_module_args, + text_encoder_module=text_encoder_module, + text_encoder_cfg=text_encoder_cfg, + noise_scheduler_module={"method": "euler"}, + infer_noise_scheduler_cfg={"validation_steps": 50}, + train_cfg={"cond_mask_prob": 0.1}, + test_cfg={ + "text_guidance_scale": 1.5, + }, + ).to(device) diff --git a/hymotion/prompt_engineering/model_constants.py b/hymotion/prompt_engineering/model_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a04a8badaa8c1c6efc31c730afef2b93489d3407 --- /dev/null +++ b/hymotion/prompt_engineering/model_constants.py @@ -0,0 +1,42 @@ +__all__ = [ + "REWRITE_AND_INFER_TIME_PROMPT_FORMAT", +] + +REWRITE_AND_INFER_TIME_PROMPT_FORMAT = """ + # Role + You are an expert in 3D motion analysis, animation timing, and choreography. Your task is to analyze textual action descriptions to estimate execution time and standardize the language for motion generation systems. + + # Task + Analyze the user-provided [Input Action] and generate a structured JSON response containing a duration estimate and a refined caption. + + # Instructions + + ### 1. Duration Estimation (frame_count) + - Analyze the complexity, speed, and physical constraints of the described action. + - Estimate the time required to perform the action in a **smooth, natural, and realistic manner**. + - Calculate the total duration in frames based on a **30 fps** (frames per second) standard. + - Output strictly as an Integer. + + ### 2. Caption Refinement (short_caption) + - Generate a refined, grammatically correct version of the input description in **English**. + - **Strict Constraints**: + - You must **PRESERVE** the original sequence of events (chronological order). + - You must **RETAIN** all original spatial modifiers (e.g., "left," "upward," "quickly"). + - **DO NOT** add new sub-actions or hallucinate details not present in the input. + - **DO NOT** delete any specific movements. + - The goal is to improve clarity and flow while maintaining 100% semantic fidelity to the original request. + + ### 3. Output Format + - Return **ONLY** a raw JSON object. + - Do not use Markdown formatting (i.e., do not use ```json ... ```). + - Ensure the JSON is valid and parsable. + + # JSON Structure + {{ + "duration": , + "short_caption": "" + }} + + # Input + {} +""" diff --git a/hymotion/prompt_engineering/prompt_rewrite.py b/hymotion/prompt_engineering/prompt_rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..cb871fc558801650efbdc04b1b9dcf50a89290ce --- /dev/null +++ b/hymotion/prompt_engineering/prompt_rewrite.py @@ -0,0 +1,304 @@ +# prompt_rewrite.py +import base64 +import concurrent.futures +import datetime +import hashlib +import hmac +import json +import logging +import random +import re +import time +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from openai import OpenAI +from requests import exceptions as req_exc + +from .model_constants import REWRITE_AND_INFER_TIME_PROMPT_FORMAT + +# logging.basicConfig(level=logging.INFO) + + +@dataclass +class ApiConfig: + host: str + user: str + apikey: str + model: str + api_version: Optional[str] = None + timeout: int = 3600 + source: str = "hymotion" + + +@dataclass +class RetryConfig: + max_retries: int = 20 + base_delay: float = 1.0 + timeout: float = 30.0 + retry_status: Tuple[int, ...] = (429, 500, 502, 503, 504) + max_delay: float = 1.0 + + +class ApiError(Exception): + pass + + +class ResponseParseError(Exception): + pass + + +class OpenAIChatApi: + def __init__(self, config: ApiConfig) -> None: + self.logger = logging.getLogger(__name__) + self.config = config + self.client = OpenAI( + api_key=self.config.apikey, + base_url=self.config.host, + ) + + def call_data_eval(self, data: Union[str, Dict[str, Any]]): + if isinstance(data, dict) and "messages" in data: + raw_msgs = data["messages"] + messages: List[Dict[str, str]] = [] + for m in raw_msgs: + role = m.get("role", "user") + content = m.get("content", "") + if isinstance(content, list): + parts = [] + for p in content: + if isinstance(p, dict) and ("text" in p): + parts.append(str(p.get("text", ""))) + content = " ".join([t for t in parts if t]) + elif not isinstance(content, str): + content = str(content) + messages.append({"role": role, "content": content}) + payload = { + "model": self.config.model, + "messages": messages, + "temperature": 0.7, + "top_p": 0.8, + } + for k in ( + "temperature", + "top_p", + "max_tokens", + "n", + "stop", + "presence_penalty", + "frequency_penalty", + "user", + ): + if k in data: + payload[k] = data[k] + else: + payload = { + "model": self.config.model, + "messages": [{"role": "user", "content": str(data)}], + "temperature": 0.7, + "top_p": 0.8, + } + try: + resp = self.client.chat.completions.create(**payload) + return resp + except Exception as e: + self.logger.error(f"OpenAI API call failed: {e}") + raise ApiError(f"OpenAI API call failed: {e}") from e + + +class ResponseParser: + def __init__(self): + self.logger = logging.getLogger(__name__) + + def call_data_eval_with_retry( + self, api: Union[OpenAIChatApi], data: str, retry_config: Optional[RetryConfig] = None + ) -> Tuple[Union[Dict[str, Any], int], float, float]: + if retry_config is None: + retry_config = RetryConfig() + + last_error = None + for attempt in range(retry_config.max_retries): + start_time = time.time() + cost = 0.0 + + try: + result = self._execute_request(api, data) + end_time = time.time() + parsed_result = self._parse_answer(result) + self._validate_result(parsed_result) + return parsed_result, cost, end_time - start_time + + except ( + concurrent.futures.TimeoutError, + req_exc.RequestException, + json.JSONDecodeError, + ValueError, + TypeError, + ResponseParseError, + ) as e: + last_error = e + self.logger.warning(f"Attempt {attempt + 1} failed: {e}") + if isinstance(e, req_exc.RequestException) and hasattr(e, "response"): + if e.response is not None and e.response.status_code not in retry_config.retry_status: + raise ApiError(f"Non-retryable error: {e.response.status_code}") from e + if attempt < retry_config.max_retries - 1: + delay = self._calculate_delay(attempt, retry_config) + self.logger.info(f"JSON parsing failed, {delay:.1f} seconds later retry...") + time.sleep(delay) + + raise ApiError(f"Retry {retry_config.max_retries} times but still failed") from last_error + + def _execute_request(self, api: Union[OpenAIChatApi], data: str) -> Dict[str, Any]: + response = api.call_data_eval(data) + + try: + if hasattr(response, "model_dump"): + return response.model_dump() + if isinstance(response, dict): + return response + if hasattr(response, "__dict__"): + return json.loads(json.dumps(response.__dict__, default=str)) + except Exception as e: + raise ResponseParseError(f"Unable to parse OpenAI returned object: {type(response)} - {e}") from e + + raise ResponseParseError(f"Unknown response type: {type(response)}") + + def _extract_cost(self, payload: Dict[str, Any]) -> float: + try: + return float(payload.get("cost_info", {}).get("cost", 0)) / 1e6 + except (AttributeError, KeyError): + return 0.0 + + def _validate_result(self, result: Union[Dict[str, Any], int]) -> None: + if isinstance(result, int): + return + elif isinstance(result, dict): + required_fields = ["duration", "short_caption"] + for field in required_fields: + if not isinstance(result.get(field), (int, str)): + raise ResponseParseError(f"LLM returned invalid format: {field}") + else: + raise ResponseParseError(f"Unsupported answer type: {type(result)}") + + def _calculate_delay(self, attempt: int, config: RetryConfig) -> float: + delay = config.base_delay * (2**attempt) * (0.5 + random.random()) + return min(delay, config.max_delay) + + def _parse_answer(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if isinstance(payload, dict) and "choices" in payload: + return self._parse_from_choices_field(payload) + + raise ResponseParseError("Unknown response format: expected choices") + + def _parse_from_choices_field(self, payload: Dict[str, Any]) -> Dict[str, Any]: + choices = payload.get("choices") or [] + if not choices: + raise ResponseParseError("OpenAI returned empty") + + content = self._extract_content_from_choice(choices[0]) + + if not isinstance(content, str) or not content.strip(): + raise ResponseParseError("OpenAI returned no valid content") + + return self._parse_json_content(content) + + def _extract_content_from_choice(self, choice: Any) -> Optional[str]: + content = None + + if isinstance(choice, dict): + # Try message content first + msg = choice.get("message") or {} + content = msg.get("content") + # Fallback to delta content or text + if content is None: + delta = choice.get("delta") or {} + content = delta.get("content", choice.get("text")) + else: + # Handle object-like choice (e.g. Pydantic model) + msg = getattr(choice, "message", None) + if msg is not None: + content = getattr(msg, "content", None) + + if content is None: + delta = getattr(choice, "delta", None) + if delta is not None: + content = getattr(delta, "content", None) + + if content is None: + content = getattr(choice, "text", None) + + return content + + def _parse_json_content(self, content: str) -> Dict[str, Any]: + cleaned = self._cleanup_fenced_json(content) + try: + return json.loads(cleaned) + except json.JSONDecodeError as e: + self.logger.warning(f"JSON parsing failed, original content: {cleaned[:500]}...") + raise ResponseParseError(f"JSON parsing failed: {e}") from e + + def _cleanup_fenced_json(self, text: str) -> str: + text = text.strip() + if text.startswith("```"): + text = re.sub(r"^```(?:json)?\s*", "", text) + text = re.sub(r"\s*```$", "", text) + if not text.lstrip().startswith("{") and "{" in text and "}" in text: + start = text.find("{") + end = text.rfind("}") + if 0 <= start < end: + text = text[start : end + 1] + return text + + +class PromptRewriter: + def __init__( + self, + host: Optional[str] = None, + parser: Optional[ResponseParser] = None, + backend: Literal["our_rewriter"] = "our_rewriter", + ): + self.parser = parser or ResponseParser() + self.logger = logging.getLogger(__name__) + self.backend = backend.lower() + + if self.backend == "our_rewriter": + self.api = OpenAIChatApi( + ApiConfig( + host=host, + user="", + apikey="EMPTY", + model="Qwen3-30B-A3B-SFT", + api_version="", + ) + ) + else: + raise ValueError(f"Invalid backend: {self.backend}") + + def rewrite_prompt_and_infer_time( + self, + text: str, + prompt_format: str = REWRITE_AND_INFER_TIME_PROMPT_FORMAT, + retry_config: Optional[RetryConfig] = None, + ) -> Tuple[float, str]: + self.logger.info("Start rewriting prompt...") + try: + result, cost, elapsed = self.parser.call_data_eval_with_retry( + self.api, prompt_format.format(text), retry_config + ) + self.logger.info(f"Rewriting completed - cost: {cost:.6f}, time: {elapsed:.2f}s") + return round(float(result["duration"]) / 30.0, 2), result["short_caption"] + + except Exception as e: + self.logger.error(f"Prompt rewriting failed: {e}") + raise + + +if __name__ == "__main__": + # python -m hymotion.prompt_engineering.prompt_rewrite + + logging.basicConfig(level=logging.INFO) + text = "person jumps after they runs" + prompt_rewriter = PromptRewriter(backend="our_rewriter") + result = prompt_rewriter.rewrite_prompt_and_infer_time(text) + print(result) diff --git a/hymotion/utils/configs.py b/hymotion/utils/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..e384ae1a38495e162bf009f7a30178c7732fd574 --- /dev/null +++ b/hymotion/utils/configs.py @@ -0,0 +1,344 @@ +import ast +import copy +import os.path as osp +import platform +import re +import shutil +import sys +import tempfile +import types +import uuid +from importlib import import_module +from pathlib import Path +from typing import Any, Dict, Iterator, NoReturn, Optional, Union +import yaml + +from .misc import import_modules_from_strings +from .path import check_file_exist + +BASE_KEY = "_base_" +DELETE_KEY = "_delete_" +RESERVED_KEYS = ["filename", "text", "pretty_text"] + + +class Config: + def __init__( + self, + cfg_dict: Optional[dict] = None, + cfg_text: Optional[str] = None, + filename: Optional[str] = None, + ) -> None: + if cfg_dict is None: + cfg_dict = dict() + elif not isinstance(cfg_dict, dict): + raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}") + for key in cfg_dict: + if key in RESERVED_KEYS: + raise KeyError(f"{key} is reserved for config file") + + if isinstance(filename, Path): + filename = str(filename) + + super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict)) + super(Config, self).__setattr__("_filename", filename) + if cfg_text: + text = cfg_text + elif filename: + with open(filename, "r") as f: + text = f.read() + else: + text = "" + super(Config, self).__setattr__("_text", text) + + @staticmethod + def fromfile( + filename: str, + use_predefined_variables: bool = True, + import_custom_modules: bool = True, + ) -> "Config": + if isinstance(filename, Path): + filename = str(filename) + cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables) + if import_custom_modules and cfg_dict.get("custom_imports", None): + import_modules_from_strings(**cfg_dict["custom_imports"]) + return Config(cfg_dict, cfg_text=cfg_text, filename=filename) + + @staticmethod + def _file2dict(filename: str, use_predefined_variables: bool = True) -> tuple[dict, str]: + filename = osp.abspath(osp.expanduser(filename)) + check_file_exist(filename) + fileExtname = osp.splitext(filename)[1] + if fileExtname not in [".py"]: + raise IOError("Only py type are supported now!") + + cfg_dict = {} + + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=fileExtname) + if platform.system() == "Windows": + temp_config_file.close() + temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars(temp_config_file.name, temp_config_file.name) + + if filename.endswith(".py"): + temp_module_name = osp.splitext(temp_config_name)[0] + sys.path.insert(0, temp_config_dir) + Config._validate_py_syntax(filename) + mod = import_module(temp_module_name) + sys.path.pop(0) + cfg_dict = { + name: value + for name, value in mod.__dict__.items() + if not name.startswith("__") + and not isinstance(value, types.ModuleType) + and not isinstance(value, types.FunctionType) + } + # delete imported module + del sys.modules[temp_module_name] + + # close temp file + temp_config_file.close() + + cfg_text = filename + "\n" + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + cfg_text += f.read() + + if BASE_KEY in cfg_dict: + cfg_dir = osp.dirname(filename) + base_filename = cfg_dict.pop(BASE_KEY) + base_filename = base_filename if isinstance(base_filename, list) else [base_filename] + + cfg_dict_list = list() + cfg_text_list = list() + for f in base_filename: + _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) + cfg_dict_list.append(_cfg_dict) + cfg_text_list.append(_cfg_text) + + base_cfg_dict = dict() + for c in cfg_dict_list: + duplicate_keys = base_cfg_dict.keys() & c.keys() + if len(duplicate_keys) > 0: + raise KeyError("Duplicate key is not allowed among bases. " f"Duplicate keys: {duplicate_keys}") + base_cfg_dict.update(c) + + # Substitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, base_cfg_dict) + assert isinstance(cfg_dict, dict) + + base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = base_cfg_dict + + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = "\n".join(cfg_text_list) + + return cfg_dict, cfg_text + + @staticmethod + def _validate_py_syntax(filename: str) -> None: + with open(filename, "r", encoding="utf-8") as f: + # Setting encoding explicitly to resolve coding issue on windows + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError("There are syntax errors in config " f"file {filename}: {e}") + + @staticmethod + def _pre_substitute_base_vars(filename: str, temp_config_name: str) -> dict: + """Substitute base variable placehoders to string, so that parsing would work.""" + with open(filename, "r", encoding="utf-8") as f: + config_file = f.read() + base_var_dict = {} + regexp = r"\{\{\s*" + BASE_KEY + r"\.([\w\.]+)\s*\}\}" + base_vars = set(re.findall(regexp, config_file)) + for base_var in base_vars: + randstr = f"_{base_var}_{uuid.uuid4().hex.lower()[:6]}" + base_var_dict[randstr] = base_var + regexp = r"\{\{\s*" + BASE_KEY + r"\." + base_var + r"\s*\}\}" + config_file = re.sub(regexp, f'"{randstr}"', config_file) + with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file: + tmp_config_file.write(config_file) + return base_var_dict + + @staticmethod + def _substitute_base_vars( + cfg: Union[dict, list, tuple, str], + base_var_dict: dict, + base_cfg: dict, + ) -> Union[dict, list, tuple, str]: + """Substitute variable strings to their actual values.""" + cfg = copy.deepcopy(cfg) + + if isinstance(cfg, dict): + for k, v in cfg.items(): + if isinstance(v, str) and v in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[v].split("."): + new_v = new_v[new_k] + cfg[k] = new_v + elif isinstance(v, (list, tuple, dict)): + cfg[k] = Config._substitute_base_vars(v, base_var_dict, base_cfg) + elif isinstance(cfg, tuple): + cfg = tuple(Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg) + elif isinstance(cfg, list): + cfg = [Config._substitute_base_vars(c, base_var_dict, base_cfg) for c in cfg] + elif isinstance(cfg, str) and cfg in base_var_dict: + new_v = base_cfg + for new_k in base_var_dict[cfg].split("."): + new_v = new_v[new_k] + cfg = new_v + + return cfg + + @staticmethod + def _substitute_predefined_vars(filename: str, temp_config_name: str) -> None: + file_dirname = osp.dirname(filename) + file_basename = osp.basename(filename) + file_basename_no_extension = osp.splitext(file_basename)[0] + file_extname = osp.splitext(filename)[1] + support_templates = dict( + fileDirname=file_dirname, + fileBasename=file_basename, + fileBasenameNoExtension=file_basename_no_extension, + fileExtname=file_extname, + ) + with open(filename, "r", encoding="utf-8") as f: + config_file = f.read() + for key, value in support_templates.items(): + regexp = r"\{\{\s*" + str(key) + r"\s*\}\}" + value = value.replace("\\", "/") + config_file = re.sub(regexp, value, config_file) + with open(temp_config_name, "w", encoding="utf-8") as tmp_config_file: + tmp_config_file.write(config_file) + + @staticmethod + def _merge_a_into_b(a: dict, b: dict, allow_list_keys: bool = False) -> dict: + b = b.copy() + for k, v in a.items(): + if allow_list_keys and k.isdigit() and isinstance(b, list): + k = int(k) + if len(b) <= k: + raise KeyError(f"Index {k} exceeds the length of list {b}") + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + elif isinstance(v, dict): + if k in b and not v.pop(DELETE_KEY, False): + allowed_types = (dict, list) if allow_list_keys else dict + if not isinstance(b[k], allowed_types): + raise TypeError( + f"{k}={v} in child config cannot inherit from " + f"base because {k} is a dict in the child config " + f"but is of type {type(b[k])} in base config. " + f"You may set `{DELETE_KEY}=True` to ignore the " + f"base config." + ) + b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) + else: + b[k] = ConfigDict(v) + else: + b[k] = v + return b + + def to_dict(self) -> Any: + def convert_configdict(obj): + if isinstance(obj, ConfigDict): + return {k: convert_configdict(v) for k, v in obj.items()} + elif isinstance(obj, dict): + return {k: convert_configdict(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [convert_configdict(item) for item in obj] + else: + return obj + + return convert_configdict(self._cfg_dict) + + @classmethod + def from_dict(cls, cfg_dict: dict, filename: Optional[str] = None) -> "Config": + return cls(cfg_dict=cfg_dict, filename=filename) + + def save_yaml(self, filename: str) -> None: + with open(filename, "w", encoding="utf-8") as f: + yaml.safe_dump(self.to_dict(), f, default_flow_style=False, indent=2) + + @classmethod + def load_yaml(cls, filename: str) -> "Config": + with open(filename, "r", encoding="utf-8") as f: + cfg_dict = yaml.safe_load(f) + return cls.from_dict(cfg_dict, filename=filename) + + def __repr__(self) -> str: + return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}" + + def __len__(self) -> int: + return len(self._cfg_dict) + + def __getattr__(self, name: str) -> Any: + return getattr(self._cfg_dict, name) + + def __getitem__(self, name: str) -> Any: + return self._cfg_dict.__getitem__(name) + + def __setattr__(self, name: str, value: Any) -> None: + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setattr__(name, value) + + def __setitem__(self, name: str, value: Any) -> None: + if isinstance(value, dict): + value = ConfigDict(value) + self._cfg_dict.__setitem__(name, value) + + def __iter__(self) -> Iterator[Any]: + return iter(self._cfg_dict) + + def __getstate__(self) -> tuple[dict, str, str]: + return (self._cfg_dict, self._filename, self._text) + + def __copy__(self) -> "Config": + cls = self.__class__ + other = cls.__new__(cls) + other.__dict__.update(self.__dict__) + + return other + + def __deepcopy__(self, memo: dict) -> "Config": + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + + for key, value in self.__dict__.items(): + super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) + + return other + + +class ConfigDict(Dict): + def __missing__(self, name: str) -> NoReturn: + raise KeyError(name) + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def to_dict(self) -> Any: + def convert_configdict(obj): + if isinstance(obj, ConfigDict): + return {k: convert_configdict(v) for k, v in obj.items()} + elif isinstance(obj, dict): + return {k: convert_configdict(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [convert_configdict(item) for item in obj] + else: + return obj + + return convert_configdict(dict(self)) diff --git a/hymotion/utils/geometry.py b/hymotion/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..b93f831ba1eccafffddb25bb27e644f9a0944a67 --- /dev/null +++ b/hymotion/utils/geometry.py @@ -0,0 +1,856 @@ +from typing import Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor + + +def rotation_6d_to_matrix(d6: Tensor) -> Tensor: + """ + Converts 6D rotation representation by Zhou et al. [1] to rotation matrix + using Gram--Schmidt orthogonalization per Section B of [1]. + Args: + d6: 6D rotation representation, of size (*, 6) + + Returns: + batch of rotation matrices of size (*, 3, 3) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + +def matrix_to_rotation_6d(matrix: Tensor) -> Tensor: + """ + Converts rotation matrices to 6D rotation representation by Zhou et al. [1] + by dropping the last row. Note that 6D representation is not unique. + Args: + matrix: batch of rotation matrices of size (*, 3, 3) + + Returns: + 6D rotation representation, of size (*, 6) + + [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. + On the Continuity of Rotation Representations in Neural Networks. + IEEE Conference on Computer Vision and Pattern Recognition, 2019. + Retrieved from http://arxiv.org/abs/1812.07035 + """ + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) + + +def standardize_quaternion(quaternions: Tensor) -> Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def _sqrt_positive_part(x: Tensor) -> Tensor: + """Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0.""" + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def matrix_to_quaternion(matrix: Tensor) -> Tensor: + """Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) + return standardize_quaternion(out) + + +def quaternion_to_axis_angle(quaternions: Tensor) -> Tensor: + """Convert rotations given as quaternions to axis/angle. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True) + half_angles = torch.atan2(norms, quaternions[..., :1]) + angles = 2 * half_angles + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles] + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + return quaternions[..., 1:] / sin_half_angles_over_angles + + +def matrix_to_axis_angle(matrix: Tensor) -> Tensor: + """Convert rotations given as rotation matrices to axis/angle. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + Rotations given as a vector in axis angle form, as a tensor + of shape (..., 3), where the magnitude is the angle + turned anticlockwise in radians around the vector's + direction. + """ + return quaternion_to_axis_angle(matrix_to_quaternion(matrix)) + + +def quaternion_to_matrix(quaternions: Tensor) -> Tensor: + """Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def axis_angle_to_quaternion(axis_angle: Tensor) -> Tensor: + """Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles] + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1) + return quaternions + + +def axis_angle_to_matrix(axis_angle: Tensor) -> Tensor: + """Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) + + +def get_T_w2c_from_wcparams( + global_orient_w: Tensor, transl_w: Tensor, global_orient_c: Tensor, transl_c: Tensor, offset: Tensor +) -> Tensor: + """ + Args: + global_orient_w: Tensor, (F, 3) + transl_w: Tensor, (F, 3) + global_orient_c: Tensor, (F, 3) + transl_c: Tensor, (F, 3) + offset: Tensor, (*, 3) + Returns: + T_w2c: Tensor, (F, 4, 4) + """ + assert global_orient_w.shape == transl_w.shape and len(global_orient_w.shape) == 2 + assert global_orient_c.shape == transl_c.shape and len(global_orient_c.shape) == 2 + + R_w = axis_angle_to_matrix(global_orient_w) # (F, 3, 3) + t_w = transl_w # (F, 3) + R_c = axis_angle_to_matrix(global_orient_c) # (F, 3, 3) + t_c = transl_c # (F, 3) + + R_w2c = R_c @ R_w.transpose(-1, -2) # (F, 3, 3) + t_w2c = t_c + offset - torch.einsum("fij,fj->fi", R_w2c, t_w + offset) # (F, 3) + T_w2c = torch.eye(4, device=global_orient_w.device).repeat(R_w.size(0), 1, 1) # (F, 4, 4) + T_w2c[..., :3, :3] = R_w2c # (F, 3, 3) + T_w2c[..., :3, 3] = t_w2c # (F, 3) + return T_w2c + + +def get_R_c2gv(R_w2c, axis_gravity_in_w=[0, 0, -1]): + """ + Args: + R_w2c: (*, 3, 3) + Returns: + R_c2gv: (*, 3, 3) + """ + if isinstance(axis_gravity_in_w, list): + axis_gravity_in_w = torch.tensor(axis_gravity_in_w).float() # gravity direction in world coord + axis_z_in_c = torch.tensor([0, 0, 1]).float() + + # get gv-coord axes in in c-coord + axis_y_of_gv = R_w2c @ axis_gravity_in_w # (*, 3) + axis_x_of_gv = axis_y_of_gv.cross(axis_z_in_c.expand_as(axis_y_of_gv), dim=-1) + # normalize + axis_x_of_gv_norm = axis_x_of_gv.norm(dim=-1, keepdim=True) + axis_x_of_gv = axis_x_of_gv / (axis_x_of_gv_norm + 1e-5) + axis_x_of_gv[axis_x_of_gv_norm.squeeze(-1) < 1e-5] = torch.tensor([1.0, 0.0, 0.0]) # use cam x-axis as axis_x_of_gv + axis_z_of_gv = axis_x_of_gv.cross(axis_y_of_gv, dim=-1) + + R_gv2c = torch.stack([axis_x_of_gv, axis_y_of_gv, axis_z_of_gv], dim=-1) # (*, 3, 3) + R_c2gv = R_gv2c.transpose(-1, -2) # (*, 3, 3) + return R_c2gv + + +def get_c_rootparam(global_orient: Tensor, transl: Tensor, T_w2c: Tensor, offset: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + global_orient: Tensor, (F, 3) + transl: Tensor, (F, 3) + T_w2c: Tensor, (*, 4, 4) + offset: Tensor, (3,) + Returns: + R_c: Tensor, (F, 3) + t_c: Tensor, (F, 3) + """ + assert global_orient.shape == transl.shape and len(global_orient.shape) == 2 + R_w = axis_angle_to_matrix(global_orient) # (F, 3, 3) + t_w = transl # (F, 3) + + R_w2c = T_w2c[..., :3, :3] # (*, 3, 3) + t_w2c = T_w2c[..., :3, 3] # (*, 3) + if len(R_w2c.shape) == 2: + R_w2c = R_w2c[None].expand(R_w.size(0), -1, -1) # (F, 3, 3) + t_w2c = t_w2c[None].expand(t_w.size(0), -1) + + R_c = matrix_to_axis_angle(R_w2c @ R_w) # (F, 3) + t_c = torch.einsum("fij,fj->fi", R_w2c, t_w + offset) + t_w2c - offset # (F, 3) + return R_c, t_c + + +def compute_cam_angvel(R_w2c, padding_last=True): + """ + R_w2c : (F, 3, 3) + """ + # R @ R0 = R1, so R = R1 @ R0^T + cam_angvel = matrix_to_rotation_6d(R_w2c[1:] @ R_w2c[:-1].transpose(-1, -2)) # (F-1, 6) + # cam_angvel = (cam_angvel - torch.tensor([[1, 0, 0, 0, 1, 0]])) * FPS + assert padding_last + cam_angvel = torch.cat([cam_angvel, cam_angvel[-1:]], dim=0) # (F, 6) + return cam_angvel.float() + + +def rot6d_to_rotation_matrix(rot6d): + """Convert 6D rotation representation to 3x3 rotation matrix. + + Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 + Args: + rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations. + Returns: + rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices. + """ + # x = rot6d.view(-1, 3, 2) + x = rot6d.view(*rot6d.shape[:-1], 3, 2) + a1 = x[..., 0] + a2 = x[..., 1] + b1 = F.normalize(a1, dim=-1) + b2 = F.normalize(a2 - torch.einsum("...i,...i->...", b1, a2).unsqueeze(-1) * b1, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-1) + + +def rotation_matrix_to_rot6d(rotation_matrix): + """Convert 3x3 rotation matrix to 6D rotation representation. + + Args: + rotation_matrix: torch tensor of shape (batch_size, 3, 3) of corresponding rotation matrices. + Returns: + rot6d: torch tensor of shape (batch_size, 6) of 6d rotation representations. + """ + v1 = rotation_matrix[..., 0:1] + v2 = rotation_matrix[..., 1:2] + rot6d = torch.cat([v1, v2], dim=-1).reshape(*v1.shape[:-2], 6) + return rot6d + + +def quaternion_to_rotation_matrix(quaternion): + """Convert quaternion coefficients to rotation matrix. + + Args: + quaternion: torch tensor of shape (batch_size, 4) in (w, x, y, z) representation. + Returns: + rotation matrix corresponding to the quaternion, torch tensor of shape (batch_size, 3, 3) + """ + + norm_quaternion = quaternion + norm_quaternion = norm_quaternion / norm_quaternion.norm(p=2, dim=-1, keepdim=True) + w, x, y, z = norm_quaternion[..., 0], norm_quaternion[..., 1], norm_quaternion[..., 2], norm_quaternion[..., 3] + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotation_matrix = torch.stack( + [ + w2 + x2 - y2 - z2, + 2 * xy - 2 * wz, + 2 * wy + 2 * xz, + 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, + 2 * yz - 2 * wx, + 2 * xz - 2 * wy, + 2 * wx + 2 * yz, + w2 - x2 - y2 + z2, + ], + dim=-1, + ) + rotation_matrix = rotation_matrix.view(*quaternion.shape[:-1], 3, 3) + return rotation_matrix + + +def quaternion_to_angle_axis(quaternion: Tensor) -> Tensor: + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (Tensor): tensor with quaternions. + + Return: + Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)) + # unpack input and compute conversion + q1: torch.Tensor = quaternion[..., 1] + q2: torch.Tensor = quaternion[..., 2] + q3: torch.Tensor = quaternion[..., 3] + sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta) + cos_theta: torch.Tensor = quaternion[..., 0] + two_theta: torch.Tensor = 2.0 * torch.where( + cos_theta < 0.0, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta) + ) + + k_pos: torch.Tensor = two_theta / sin_theta + k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta) + k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError("Input size must be a three dimensional tensor. Got {}".format(rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + hom = ( + torch.tensor([0, 0, 1], dtype=rotation_matrix.dtype, device=rotation_matrix.device) + .reshape(1, 3, 1) + .expand(rotation_matrix.shape[0], -1, -1) + ) + rotation_matrix = torch.cat([rotation_matrix, hom], dim=-1) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack( + [rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], + -1, + ) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack( + [rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], + -1, + ) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack( + [rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], + -1, + ) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack( + [t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], + -1, + ) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * ~mask_d0_d1 + mask_c2 = ~mask_d2 * mask_d0_nd1 + mask_c3 = ~mask_d2 * ~mask_d0_nd1 + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa # noqa + q *= 0.5 + return q + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """ + This function is borrowed from https://github.com/kornia/kornia + + Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. + + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + origin_shape = rotation_matrix.shape[:-2] + flat_rot = rotation_matrix.reshape(-1, *rotation_matrix.shape[-2:]) + if flat_rot.shape[1:] == (3, 3): + rot_mat = flat_rot + hom = ( + torch.tensor([0, 0, 1], dtype=rotation_matrix.dtype, device=rotation_matrix.device) + .reshape(1, 3, 1) + .expand(rot_mat.shape[0], -1, -1) + ) + flat_rot = torch.cat([rot_mat, hom], dim=-1) + + quaternion = rotation_matrix_to_quaternion(flat_rot) + aa = quaternion_to_angle_axis(quaternion) + aa[torch.isnan(aa)] = 0.0 + aa = aa.reshape(*origin_shape, 3) + return aa + + +def quat_to_rotmat(quat): + """Convert quaternion coefficients to rotation matrix. + + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack( + [ + w2 + x2 - y2 - z2, + 2 * xy - 2 * wz, + 2 * wy + 2 * xz, + 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, + 2 * yz - 2 * wx, + 2 * xz - 2 * wy, + 2 * wx + 2 * yz, + w2 - x2 - y2 + z2, + ], + dim=1, + ).view(B, 3, 3) + return rotMat + + +def angle_axis_to_rotation_matrix(theta): + """Convert axis-angle representation to rotation matrix. + + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + origin_shape = theta.shape[:-1] + flat_theta = theta.reshape(-1, 3) + l1norm = torch.norm(flat_theta + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(flat_theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim=1) + rot_mat = quat_to_rotmat(quat) + return rot_mat.reshape(*origin_shape, 3, 3) + + +def rotation_matrix_to_euler_angles(rotation_matrix): + """Convert 3x3 rotation matrix to Euler angles.""" + is_torch = False + if isinstance(rotation_matrix, Tensor): + is_torch = True + device = rotation_matrix.device + rotation_matrix = rotation_matrix.cpu().numpy() + from scipy.spatial.transform import Rotation + + rot_flat = rotation_matrix.reshape(-1, 3, 3) + euler_angles = Rotation.from_matrix(rot_flat).as_euler("xyz", degrees=True) + if is_torch: + return torch.from_numpy(euler_angles).to(device) + return euler_angles + + +def euler_angles_to_rotation_matrix(euler_angles, degrees=True): + """Convert Euler angles to 3x3 rotation matrix. + + Args: + euler_angles: Euler angles in xyz order, shape = [B, 3] or any shape with last dimension 3 + degrees: Whether the angles are in degrees (True) or radians (False) + + Returns: + Rotation matrix corresponding to the Euler angles, shape = [..., 3, 3] + """ + from scipy.spatial.transform import Rotation + + orig_shape = euler_angles.shape[:-1] + euler_flat = euler_angles.reshape(-1, 3) + rot_flat = Rotation.from_euler("xyz", euler_flat, degrees=degrees).as_matrix() + return rot_flat.reshape(*orig_shape, 3, 3) + + +def get_local_transl_vel(transl, global_orient_R, fps=30): + """ + transl velocity is in local coordinate (or, SMPL-coord) + Args: + transl: (*, L, 3) + global_orient: (*, L, 3, 3) + Returns: + transl_vel: (*, L, 3) + """ + transl_vel = transl[..., 1:, :] - transl[..., :-1, :] # (B, L-1, 3) + transl_vel = torch.cat([torch.zeros_like(transl_vel[:1]), transl_vel], dim=-2) # (B, L, 3) last-padding + transl_vel = transl_vel * fps + + # v_local = R^T @ v_global + local_transl_vel = torch.einsum("...lij,...li->...lj", global_orient_R, transl_vel) + return local_transl_vel + + +def compute_transl_full_cam(pred_cam, bbx_xys, K_fullimg): + s, tx, ty = pred_cam[..., 0], pred_cam[..., 1], pred_cam[..., 2] + focal_length = K_fullimg[..., 0, 0] + + icx = K_fullimg[..., 0, 2] + icy = K_fullimg[..., 1, 2] + sb = s * bbx_xys[..., 2] + cx = 2 * (bbx_xys[..., 0] - icx) / (sb + 1e-9) + cy = 2 * (bbx_xys[..., 1] - icy) / (sb + 1e-9) + tz = 2 * focal_length / (sb + 1e-9) + + cam_t = torch.stack([tx + cx, ty + cy, tz], dim=-1) + return cam_t + + +def quaternion_fix_continuity(q: Tensor) -> Tensor: + """Force quaternion continuity across the time dimension by selecting the representation (q or -q) with minimal + distance (or, equivalently, maximal dot product) between two consecutive frames.""" + assert q.ndim in ( + 2, + 3, + ), f"Expected 3D tensor (L, J, 4), or 2D tensor (L, 4), but got shape {q.shape}" + assert q.shape[-1] == 4, f"Last dimension should be 4 for quaternions, got {q.shape[-1]}" + if q.shape[0] <= 1: + return q.clone() # single frame or empty sequence, no need to process + + result = q.clone() + # compute the dot product between consecutive frames (L-1, J) or (L-1) + dot_products = torch.sum(q[1:] * q[:-1], dim=-1) + # find the negative dot product (indicates need to flip sign) + flip_mask = dot_products < 0 + # accumulate the flip mask, ensure consistency + # if a frame needs to be flipped, all subsequent frames need to be flipped the same number of times + flip_mask = (torch.cumsum(flip_mask.int(), dim=0) % 2).bool() + # flip the sign of the frames that need to be flipped + result[1:][flip_mask] *= -1 + return result + + +def rot_mat2trans_mat(rot_mat: np.ndarray) -> np.ndarray: + # assert rot_mat.shape == (3, 3) + trans_mat = np.identity(4) + if len(rot_mat.shape) == 2: + trans_mat = trans_mat + elif len(rot_mat.shape) == 3: + trans_mat = np.tile(trans_mat, [rot_mat.shape[0], 1, 1]) + elif len(rot_mat.shape) == 4: + trans_mat = np.tile(trans_mat, [rot_mat.shape[0], rot_mat.shape[1], 1, 1]) + else: + raise NotImplementedError + trans_mat[..., :3, :3] = rot_mat + return trans_mat + + +def trans2trans_mat(trans: np.ndarray) -> np.ndarray: + assert trans.shape[-1] == 3 + assert (len(trans.shape) == 1) or (len(trans.shape) == 2) or (len(trans.shape) == 3), trans.shape + if len(trans.shape) == 1: + trans_mat = np.identity(4) + trans_mat[:3, 3] = trans + elif len(trans.shape) == 2: + trans_mat = np.tile(np.identity(4), [trans.shape[0], 1, 1]) + trans_mat[:, :3, 3] = trans + elif len(trans.shape) == 3: + trans_mat = np.tile(np.identity(4), [trans.shape[0], trans.shape[1], 1, 1]) + trans_mat[:, :, :3, 3] = trans + else: + raise NotImplementedError + return trans_mat + + +def gaussian_kernel1d(sigma: float, order: int, radius: int) -> np.ndarray: + """Computes a 1D Gaussian convolution kernel. + + (from scipy) + """ + if order < 0: + raise ValueError("order must be non-negative") + exponent_range = np.arange(order + 1) + sigma2 = sigma * sigma + x = np.arange(-radius, radius + 1) + phi_x = np.exp(-0.5 / sigma2 * x**2) + phi_x = phi_x / phi_x.sum() + + if order == 0: + return phi_x + else: + # f(x) = q(x) * phi(x) = q(x) * exp(p(x)) + # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x) + # p'(x) = -1 / sigma ** 2 + # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the + # coefficients of q(x) + q = np.zeros(order + 1) + q[0] = 1 + D = np.diag(exponent_range[1:], 1) # D @ q(x) = q'(x) + P = np.diag(np.ones(order) / -sigma2, -1) # P @ q(x) = q(x) * p'(x) + Q_deriv = D + P + for _ in range(order): + q = Q_deriv.dot(q) + q = (x[:, None] ** exponent_range).dot(q) + return q * phi_x + + +def slice_seq_with_padding(whole_seq: np.ndarray, middle_idx: int, length: int) -> np.ndarray: + whole_seq_padded = whole_seq.copy() + if middle_idx - length // 2 < 0: + # need padding + l_pad_len = length // 2 - middle_idx + whole_seq_padded = np.concatenate([np.stack([whole_seq_padded[0]] * l_pad_len), whole_seq_padded], axis=0) + else: + l_pad_len = 0 + if middle_idx + length - length // 2 > len(whole_seq): + r_pad_len = middle_idx + length - length // 2 - len(whole_seq) + whole_seq_padded = np.concatenate([whole_seq_padded, np.stack([whole_seq_padded[-1]] * r_pad_len)], axis=0) + else: + r_pad_len = 0 + assert len(whole_seq_padded) == len(whole_seq) + l_pad_len + r_pad_len + middle_idx_padded = middle_idx + l_pad_len + assert middle_idx_padded - length // 2 >= 0 + assert middle_idx_padded + length - length // 2 <= len(whole_seq_padded) + return whole_seq_padded[middle_idx_padded - length // 2 : middle_idx_padded - length // 2 + length] + + +def wavg_quaternion_markley(Q: np.ndarray, weights: np.ndarray) -> np.ndarray: + """ + Averaging Quaternions. + This is a python implementation of Tolga Birdal's algorithm by https://stackoverflow.com/a/49690919 + + Arguments: + Q(ndarray): an Mx4 ndarray of quaternions. + weights(list): an M elements list, a weight for each quaternion. + + refer to Tolga Birdal's matlab implementation on + https://ww2.mathworks.cn/matlabcentral/fileexchange/40098-tolgabirdal-averaging_quaternions?s_tid=prof_contriblnk&s_tid=mwa_osa_a + by Tolga Birdal + Q is an Mx4 matrix of quaternions. weights is an Mx1 vector, a weight for + each quaternion. + Qavg is the weighted average quaternion + This function is especially useful for example when clustering poses + after a matching process. In such cases a form of weighting per rotation + is available (e.g. number of votes), which can guide the trust towards a + specific pose. weights might then be interpreted as the vector of votes + per pose. + Markley, F. Landis, Yang Cheng, John Lucas Crassidis, and Yaakov Oshman. + "Averaging quaternions." Journal of Guidance, Control, and Dynamics 30, + no. 4 (2007): 1193-1197. + """ + + # Form the symmetric accumulator matrix + # pdb.set_trace() + A = np.zeros((4, 4)) + M = Q.shape[0] + wSum = 0 + + for i in range(M): + q = Q[i, :] + w_i = weights[i] + if q[0] < 0: + # handle the antipodal configuration + q = -q + A += w_i * (np.outer(q, q)) # rank 1 update + wSum += w_i + + # scale + A /= wSum + + # Get the eigenvector corresponding to largest eigen value + return np.linalg.eigh(A)[1][:, -1] diff --git a/hymotion/utils/loaders.py b/hymotion/utils/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..e407876c55826f96b4b9b03d289535f24793aab1 --- /dev/null +++ b/hymotion/utils/loaders.py @@ -0,0 +1,184 @@ +import importlib +import json +import os + + +def load_object(module_name, module_args, **extra_args): + module_args = module_args.copy() + module_path = ".".join(module_name.split(".")[:-1]).replace("/", ".") + module = importlib.import_module(module_path) + name = module_name.split(".")[-1] + if module_args is None: + module_args = {} + module_args.update(extra_args) + obj = getattr(module, name)(**module_args) + return obj + + +def load_module(module_name): + module_path = module_name.split(".")[0].replace("/", ".") + module = importlib.import_module(module_path) + name = module_name.split(".")[-1] + obj = getattr(module, name) + return obj + + +def check_cfg(cfg, global_dict, verbose=True): + for key, val in cfg.items(): + if isinstance(val, dict): + check_cfg(val, global_dict, verbose) + elif isinstance(val, str): + if val.startswith("$"): + if verbose: + print(f" - Update {key} with {val} = {global_dict[val[1:]]}") + cfg[key] = global_dict[val[1:]] + + +def read_yaml(yamlname): + import yaml + + with open(yamlname, "r", encoding="utf-8") as file: + try: + data = yaml.safe_load(file) + except yaml.constructor.ConstructorError: + file.seek(0) + data = yaml.load(file, Loader=yaml.FullLoader) + if hasattr(data, "to_dict"): + data = data.to_dict() + elif hasattr(data, "_cfg_dict"): + data = dict(data._cfg_dict) + + return data + + +def write_yaml(data, yamlname): + import yaml + + with open(yamlname, "w", encoding="utf-8") as file: + yaml.dump(data, file) + + +def check_input(data, verbose=True): + data_parent = {} + if "input" in data: + if verbose: + print(" - Check input file list") + for filename in data.pop("input"): + cfg_new = read_yaml(filename) + data_parent.update(cfg_new) + return data_parent + + +def merge_dict(dict_A, dict_B, key, verbose=True): + if isinstance(dict_A[key], dict): + dict_B = dict_B.copy() + for key2, val2 in dict_A[key].items(): + if key2 in dict_B[key]: + merge_dict(dict_A[key], dict_B[key], key2, verbose) + dict_B[key].pop(key2) + if len(dict_B[key]) > 0: + if verbose: + print(f" - Create {key} with {dict_B[key]}") + for key2, val2 in dict_B[key].items(): + dict_A[key][key2] = val2 + else: + if verbose: + print(f" - Update {key} with {dict_B[key]}") + dict_A[key] = dict_B[key] + + +def read_config(cfgname, verbose=True): + data_base = read_yaml(cfgname) + data_parent = check_input(data_base, verbose) + # merge the data_base to data_parent + for key, val in data_parent.items(): + if key in data_base: + merge_dict(data_parent, data_base, key, verbose) + if verbose: + print(data_parent[key]) + data_base.pop(key) + data_parent.update(data_base) + data = data_parent + check_cfg(data, data, verbose) + return data + + +def update_config(config, args): + for key, value in vars(args).items(): + if key in config.keys() and value is not None: + config[key] = value + + +def read_yaml_full(path): + import yaml + + with open(path, "r") as f: + return yaml.load(f, Loader=yaml.FullLoader) + + +def check_ceph_path(path): + import os + + if os.path.exists(path): + return path + else: + raise ValueError(f"{path} not found") + + +def read_json(filename): + with open(filename, "r", encoding="utf-8") as f: + return json.load(f) + + +def write_json(data, filename): + with open(filename, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=4) + + +def load_h5_dataset(filename, ds_name_list=None, parser=None): + import h5py + + # ds for dataset + if "@" in filename: + filename, start_end = filename.split("@") + start = int(start_end.split(":")[0]) + end = int(start_end.split(":")[1]) + else: + start = None + end = None + assert os.path.isfile(filename), "cannot find: {}".format(filename) + + def load_dict(d): + ds_dict = {} + for item in d.keys(): + if ds_name_list is not None and item not in ds_name_list: + continue + if isinstance(d[item], h5py._hl.dataset.Dataset): + ds_dict[item] = d[item][()] + if parser is not None and item in parser: + ds_dict[item] = parser[item](ds_dict[item]) + elif isinstance(d[item], h5py._hl.group.Group): + ds_dict[item] = load_dict(d[item]) + for item in d.attrs.keys(): + ds_dict[item] = d.attrs[item] + return ds_dict + + with h5py.File(filename, "r") as f: + ds_dict = load_dict(f) + for item in f.attrs.keys(): + ds_dict[item] = f.attrs[item] + if start is not None and end is not None: + for key in ["LclRotation", "LclTranslation"]: + ds_dict[key] = ds_dict[key][start:end] + return ds_dict + + +if __name__ == "__main__": + # hymotion.utils.loaders + network = load_object("hymotion.utils.base_example.ToyNetwork", {}) + print(network) + network = load_object("hymotion/utils/base_example.ToyNetwork", {}) + print(network) + load_object("diffusers.DDIMScheduler", {}) + module = load_object("torch.nn.MSELoss", {"reduction": "none"}) + print(module) diff --git a/hymotion/utils/misc.py b/hymotion/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7e745b6ef6cb08c8d46c19c8b3dc353ece97e46a --- /dev/null +++ b/hymotion/utils/misc.py @@ -0,0 +1,136 @@ +import warnings +from collections.abc import Iterable, Sequence +from importlib import import_module +from itertools import repeat +from os import path as osp +from typing import Any, Callable, Optional, Tuple, Union + + +def is_str(x: Any) -> bool: + """Whether the input is an string instance. + + Note: This method is deprecated since python 2 is no longer supported. + """ + return isinstance(x, str) + + +def is_seq_of(seq: Any, expected_type: Any, seq_type: Any = None) -> bool: + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. + Returns: + bool: Whether the sequence is valid. + """ + if seq_type is None: + exp_seq_type = Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def is_list_of(seq: Any, expected_type: Any) -> bool: + """Check whether it is a list of some type. + + A partial method of :func:`is_seq_of`. + """ + return is_seq_of(seq, expected_type, seq_type=list) + + +def is_tuple_of(seq: Any, expected_type: Any) -> bool: + """Check whether it is a tuple of some type. + + A partial method of :func:`is_seq_of`. + """ + return is_seq_of(seq, expected_type, seq_type=tuple) + + +def import_modules_from_strings( + imports: Union[list[str], str], allow_failed_imports: bool = False +) -> Optional[list[Any]]: + if not imports: + return + single_import = False + if isinstance(imports, str): + single_import = True + imports = [imports] + if not isinstance(imports, list): + raise TypeError(f"custom_imports must be a list but got type {type(imports)}") + imported = [] + for imp in imports: + if not isinstance(imp, str): + raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.") + try: + imported_tmp = import_module(imp) + except ImportError: + if allow_failed_imports: + warnings.warn(f"{imp} failed to import and is ignored.", UserWarning) + imported_tmp = None + else: + raise ImportError + imported.append(imported_tmp) + if single_import: + imported = imported[0] + return imported + + +def _ntuple(n: int) -> Callable: + def parse(x: Any) -> Tuple: + if isinstance(x, Iterable) and not isinstance(x, str): + x = tuple(x) + if len(x) == 1: + x = tuple(repeat(x[0], n)) + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) + + +def seconds_to_hmsms(seconds: float) -> tuple[int, int, int, int]: + hours, remainder = divmod(seconds, 3600) + minutes, remainder = divmod(remainder, 60) + seconds, milliseconds = divmod(remainder, 1) + milliseconds *= 1000 + return int(hours), int(minutes), int(seconds), int(milliseconds) + + +def frames_to_hmsms(frames: int, frame_rate: int = 30) -> tuple[int, int, int, int]: + seconds = frames / frame_rate + return seconds_to_hmsms(seconds) + + +def make_series( + data_root: str, + series_name: str, + count: int, + date: str, + postfix: str = "raw_caption/", +): + return { + f"{series_name}_packed{i:02d}": { + "input_text_path": [ + osp.join( + data_root, + series_name, + f"{series_name}_packed{i:02d}", + date, + f"{postfix}", + ) + ] + } + for i in range(count) + } diff --git a/hymotion/utils/motion_process.py b/hymotion/utils/motion_process.py new file mode 100644 index 0000000000000000000000000000000000000000..baba6f8b7fdd5c9b8a991251f66fab10a3a70b03 --- /dev/null +++ b/hymotion/utils/motion_process.py @@ -0,0 +1,154 @@ +from typing import List, Optional, Tuple + +import numpy as np +import torch +from torch import Tensor + + +def _hysteresis_and_morph( + prob: Tensor, + on_thr: float = 0.7, + off_thr: float = 0.5, + morph_min_len: int = 3, + morph_max_gap: int = 2, +) -> Tensor: + L, K = prob.shape + device = prob.device + contact = torch.zeros_like(prob, dtype=torch.bool) + prev = torch.zeros((K,), dtype=torch.bool, device=device) + for t in range(L): + on = prob[t] > on_thr + off = prob[t] < off_thr + prev = torch.where(on, torch.ones_like(prev, dtype=torch.bool), prev) + prev = torch.where(off, torch.zeros_like(prev, dtype=torch.bool), prev) + contact[t] = prev + + def morph_clean(x: Tensor, min_len: int = 3, max_gap: int = 2) -> Tensor: + x = x.clone() + cnt = 0 + for tt in range(L): + if x[tt]: + cnt += 1 + if (not x[tt]) or tt == L - 1: + if 0 < cnt < min_len: + x[tt - cnt : tt] = False + cnt = 0 + gap = 0 + last_on = -1 + for tt in range(L): + if x[tt]: + if 0 < gap <= max_gap and last_on >= 0: + x[last_on + 1 : tt] = True + last_on = tt + gap = 0 + else: + gap += 1 + return x + + return torch.stack( + [morph_clean(contact[:, j], morph_min_len, morph_max_gap) for j in range(K)], + dim=1, + ) + + +def correct_translation_with_contact( + k3d: Tensor, + transl: Tensor, + prob: Tensor, + joint_ids: List[int] = [7, 10, 8, 11], + on_thr: float = 0.50, + off_thr: float = 0.30, + morph_min_len: int = 3, + morph_max_gap: int = 2, + eps: float = 1e-8, +) -> Tensor: + if k3d.dim() == 3: # (L, J, 3) -> (1, L, J, 3) + k3d = k3d.unsqueeze(0) + if transl.dim() == 2: # (L, 3) -> (1, L, 3) + transl = transl.unsqueeze(0) + B, L, J, _ = k3d.shape + K = len(joint_ids) + if prob.dim() == 2: # (L, K) + contact = _hysteresis_and_morph(prob, on_thr, off_thr, morph_min_len, morph_max_gap) # (L, K) + contact = contact.unsqueeze(0).expand(B, -1, -1) # (B, L, K) + prob_b = prob.unsqueeze(0).expand(B, -1, -1) # (B, L, K) + elif prob.dim() == 3: # (B, L, K) + contact_list = [] + prob_b = prob + for b in range(prob.shape[0]): + contact_list.append(_hysteresis_and_morph(prob[b], on_thr, off_thr, morph_min_len, morph_max_gap)) + contact = torch.stack(contact_list, dim=0) # (B, L, K) + else: + raise ValueError("prob must be (L,K) or (B,L,K)") + pair_contact = contact[:, 1:] & contact[:, :-1] # (B, L-1, K) + pred_j3d_static = k3d[:, :, joint_ids, :] # (B, L, K, 3) + pred_j3d_static_disp = pred_j3d_static[:, 1:] - pred_j3d_static[:, :-1] # (B, L-1, K, 3) + w = 0.5 * (prob_b[:, 1:] + prob_b[:, :-1]) # (B, L-1, K) + w = w * pair_contact.float() + w_sum = w.sum(dim=2, keepdim=True).clamp_min(eps) # (B, L-1, 1) + drift = (pred_j3d_static_disp * w.unsqueeze(-1)).sum(dim=2) / w_sum # (B, L-1, 3) + drift[..., 1] = 0.0 + w_disp = transl[:, 1:] - transl[:, :-1] # (B, L-1, 3) + w_disp_new = w_disp - drift + transl_fixed = torch.zeros_like(transl) + transl_fixed[:, 0] = transl[:, 0] + transl_fixed[:, 1:] = transl_fixed[:, :1] + torch.cumsum(w_disp_new, dim=1) + return transl_fixed.squeeze(0) if transl_fixed.shape[0] == 1 else transl_fixed + + +def smooth_quats(quats: np.ndarray, sigma: float = 1.0) -> np.ndarray: + from .geometry import gaussian_kernel1d, quaternion_fix_continuity, slice_seq_with_padding, wavg_quaternion_markley + + if len(quats) == 0 or sigma <= 0: + return quats.copy() + + q_all = quaternion_fix_continuity(torch.from_numpy(quats)).numpy() + + results = q_all.copy() + truncate = 4.0 + order = 0 + lw = int(truncate * float(sigma) + 0.5) + weights = gaussian_kernel1d(sigma=sigma, order=order, radius=lw)[::-1] + kernel_len = len(weights) + + for fr in range(len(q_all)): + cur_quats = slice_seq_with_padding(q_all, fr, kernel_len) # (K,4) + ref = cur_quats[kernel_len // 2 : kernel_len // 2 + 1] # (1,4) + dots = (cur_quats * ref).sum(axis=-1, keepdims=True) # (K,1) + cur_quats = np.where(dots < 0.0, -cur_quats, cur_quats) + + results[fr, :] = wavg_quaternion_markley(cur_quats, weights) + + return results.copy() + + +def smooth_rotation( + quats: np.ndarray, + # joint_names: List[str], + # smooth_joints: List[str], + sigma: float = 1.0, +) -> np.ndarray: + from .geometry import quaternion_fix_continuity + + if quats.ndim == 4: + is_batch = True + else: + is_batch = False + quats = quats[None, ...] + for b in range(quats.shape[0]): + for j_idx in range(quats.shape[2]): + cur_quats = quats[b, :, j_idx].copy() + cur_quats_t = quaternion_fix_continuity(torch.from_numpy(cur_quats)).numpy() + quats[b, :, j_idx] = smooth_quats(cur_quats_t, sigma=sigma) + if not is_batch: + quats = quats.squeeze(0) + return quats + + +def unwrap_euler_over_time(xyz: torch.Tensor) -> torch.Tensor: + # xyz: (B, L, J, 3) + # y[t] = y[0] + cumsum(wrap(Δy)) + y = xyz.clone() + dy = torch.atan2(torch.sin(y[:, 1:] - y[:, :-1]), torch.cos(y[:, 1:] - y[:, :-1])) + y[:, 1:] = y[:, :1] + torch.cumsum(dy, dim=1) + return y diff --git a/hymotion/utils/path.py b/hymotion/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d43c6ae50c280524317711467edd2b5b69ed03 --- /dev/null +++ b/hymotion/utils/path.py @@ -0,0 +1,168 @@ +import os +import os.path as osp +import platform +from pathlib import Path +from typing import Any, Generator, List, Optional, Union + +from .misc import is_str + +if platform.system() == "Windows": + import regex as re +else: + import re + + +def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None: + if not osp.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def mkdir_or_exist(dir_name: str, mode: int = 0o777) -> None: + if dir_name == "": + return + dir_name = osp.expanduser(dir_name) + os.makedirs(dir_name, mode=mode, exist_ok=True) + + +def symlink(src: str, dst: str, overwrite: bool = True, **kwargs) -> None: + if os.path.lexists(dst) and overwrite: + os.remove(dst) + os.symlink(src, dst, **kwargs) + + +def is_filepath(x: Any) -> bool: + return is_str(x) or isinstance(x, Path) + + +def scandir( + dir_path: Union[str, Path], + suffix: Optional[str] = None, + recursive: bool = False, + case_sensitive: bool = True, +) -> Generator[str, None, None]: + """Scan a directory to find the interested files. + + Args: + dir_path (str | :obj:`Path`): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + case_sensitive (bool, optional) : If set to False, ignore the case of + suffix. Default: True. + Returns: + A generator for all the interested files with relative paths. + """ + if isinstance(dir_path, (str, Path)): + dir_path = str(dir_path) + else: + raise TypeError('"dir_path" must be a string or Path object') + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + if suffix is not None and not case_sensitive: + suffix = suffix.lower() if isinstance(suffix, str) else tuple(item.lower() for item in suffix) + + root = dir_path + + def _scandir( + dir_path: Union[str, Path], + suffix: Optional[str], + recursive: bool, + case_sensitive: bool, + ) -> Generator[str, None, None]: + for entry in os.scandir(dir_path): + if not entry.name.startswith(".") and entry.is_file(): + rel_path = osp.relpath(entry.path, root) + _rel_path = rel_path if case_sensitive else rel_path.lower() + if suffix is None or _rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix, recursive, case_sensitive) + + return _scandir(dir_path, suffix, recursive, case_sensitive) + + +def find_files(directory, pattern, recursive=True, abspath=False) -> List[str]: + regex = re.compile(pattern) + file_list = [] + for root, _, files in os.walk(directory): + for f in files: + if regex.match(f) is not None: + file_list.append(os.path.join(root, f)) + if not recursive: + break + map_func = os.path.abspath if abspath else os.path.relpath + return list(map(map_func, sorted(file_list))) + + +def natural_keys(text: str, retoken: str = r"[a-zA-Z]*(\d+)[a-zA-Z_]*[\.].*", n: int = 1) -> Union[int, str]: + def _atoi(text: str) -> Union[int, str]: + return int(text) if text.isdigit() else text.lower() + + return _atoi(re.split(retoken, text)[n]) + + +listdirs = lambda root: [osp.join(base, d) for base, dirs, _ in os.walk(root) if dirs for d in dirs] + +listfiles = lambda root: [f for base, _, files in os.walk(root) if files for f in files] + + +def parse_dirs_and_sort( + input_dirs: Union[list, str], + suffix: str, + is_sort: bool = False, + with_prefix: bool = True, +) -> List[str]: + if isinstance(input_dirs, list): + input_dirs_list = [] + for iter_input_dir in input_dirs: + if osp.isdir(iter_input_dir): + input_dirs_list += [ + osp.join(iter_input_dir, x) if with_prefix else x + for x in scandir( + iter_input_dir, + suffix=suffix, + recursive=True, + case_sensitive=False, + ) + ] + elif osp.isfile(iter_input_dir): + if iter_input_dir.endswith(suffix): + input_dirs_list += [iter_input_dir] + else: + raise ValueError(f"Input path {iter_input_dir} is not exist.") + elif isinstance(input_dirs, str): + if osp.isdir(input_dirs): + input_dirs_list = [ + osp.join(input_dirs, x) if with_prefix else x + for x in scandir(input_dirs, suffix=suffix, recursive=True, case_sensitive=False) + ] + elif osp.isfile(input_dirs): + if input_dirs.endswith(suffix): + input_dirs_list = [input_dirs] + else: + input_dirs_list = [] + else: + raise ValueError(f"Input path {input_dirs} is not exist.") + else: + raise ValueError("Only support list or str input.") + + if is_sort: + try: + try: + input_dirs_list = sorted( + input_dirs_list, + key=lambda text: ( + natural_keys(text, retoken=r"[a-zA-Z]*(\d+)_[0-9a-zA-Z_]*[\.].*", n=1), + natural_keys(text, retoken=r"[0-9a-zA-Z]*_(\d+)[a-zA-Z_]*[\.].*", n=1), + ), + ) + except: + input_dirs_list = sorted(input_dirs_list, key=lambda text: (natural_keys(text))) + except: + input_dirs_list = sorted(input_dirs_list, key=lambda text: text) + + return input_dirs_list diff --git a/hymotion/utils/smplh2fbx.py b/hymotion/utils/smplh2fbx.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b4f3e686dbeb278dbf243bcfcf474b7dd3860b --- /dev/null +++ b/hymotion/utils/smplh2fbx.py @@ -0,0 +1,585 @@ +import glob +import os +import shutil +import sys +import tempfile + +import fbx +import numpy as np +import torch +from transforms3d.euler import mat2euler + +from .geometry import angle_axis_to_rotation_matrix, rot_mat2trans_mat, trans2trans_mat + +# yapf: disable +SMPLH_JOINT2NUM = { + "Pelvis": 0, "L_Hip": 1, "R_Hip": 2, "Spine1": 3, + "L_Knee": 4, "R_Knee": 5, "Spine2": 6, + "L_Ankle": 7, "R_Ankle": 8, + "Spine3": 9, + "L_Foot": 10, "R_Foot": 11, + "Neck": 12, "L_Collar": 13, "R_Collar": 14, "Head": 15, + "L_Shoulder": 16, "R_Shoulder": 17, + "L_Elbow": 18, "R_Elbow": 19, + "L_Wrist": 20, "R_Wrist": 21, + # "Jaw": 22, "L_Eye": 23, "R_Eye": 24, + "L_Index1": 22, "L_Index2": 23, "L_Index3": 24, + "L_Middle1": 25, "L_Middle2": 26, "L_Middle3": 27, + "L_Pinky1": 28, "L_Pinky2": 29, "L_Pinky3": 30, + "L_Ring1": 31, "L_Ring2": 32, "L_Ring3": 33, + "L_Thumb1": 34, "L_Thumb2": 35, "L_Thumb3": 36, + "R_Index1": 37, "R_Index2": 38, "R_Index3": 39, + "R_Middle1": 40, "R_Middle2": 41, "R_Middle3": 42, + "R_Pinky1": 43, "R_Pinky2": 44, "R_Pinky3": 45, + "R_Ring1": 46, "R_Ring2": 47, "R_Ring3": 48, + "R_Thumb1": 49, "R_Thumb2": 50, "R_Thumb3": 51, +} +# yapf: enable + + +def _parse_obj_file(obj_path): + vertices = [] + uv_coords = [] + faces = [] + uv_faces = [] + + with open(obj_path, "r") as f: + for line in f: + line = line.strip() + if line.startswith("v "): + parts = line.split() + vertices.append([float(parts[1]), float(parts[2]), float(parts[3])]) + elif line.startswith("vt "): + parts = line.split() + uv_coords.append([float(parts[1]), float(parts[2])]) + elif line.startswith("f "): + parts = line.split() + face_vertices = [] + face_uvs = [] + for part in parts[1:]: + indices = part.split("/") + face_vertices.append(int(indices[0]) - 1) + if len(indices) > 1 and indices[1]: + face_uvs.append(int(indices[1]) - 1) + + if len(face_vertices) == 3: + faces.append(face_vertices) + if len(face_uvs) == 3: + uv_faces.append(face_uvs) + + return np.array(vertices), np.array(uv_coords), np.array(faces), np.array(uv_faces) + + +def _blend_shapes(betas: torch.Tensor, shape_disps: torch.Tensor) -> torch.Tensor: + """Calculates the per vertex displacement due to the blend shapes. + + Parameters + ---------- + betas : torch.tensor Bx(num_betas) + Blend shape coefficients + shape_disps: torch.tensor Vx3x(num_betas) + Blend shapes + + Returns + ------- + torch.tensor BxVx3 + The per-vertex displacement due to shape deformation + """ + + # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] + # i.e. Multiply each shape displacement by its corresponding beta and + # then sum them. + blend_shape = torch.einsum("bl,mkl->bmk", [betas, shape_disps]) + return blend_shape + + +def _vertices2joints(J_regressor: torch.Tensor, vertices: torch.Tensor) -> torch.Tensor: + """Calculates the 3D joint locations from the vertices. + + Parameters + ---------- + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from the + position of the vertices + vertices : torch.tensor BxVx3 + The tensor of mesh vertices + + Returns + ------- + torch.tensor BxJx3 + The location of the joints + """ + + return torch.einsum("bik,ji->bjk", [vertices, J_regressor]) + + +def _addSmplXMesh(fbxScene, v_posed, faces, uv_coords=None, uv_faces=None): + # Obtain a reference to the scene's root node. + rootNode = fbxScene.GetRootNode() + + # Create a new node in the scene. + geometryNode = fbx.FbxNode.Create(fbxScene, "Geometry") + rootNode.AddChild(geometryNode) + + # Create a new mesh node attribute in the scene, and + # set it as the new node's attribute + mesh = fbx.FbxMesh.Create(fbxScene, "body") + geometryNode.SetNodeAttribute(mesh) + + # Define the new mesh's control points. + # v_posed, faces = smplx['v_posed'], smplx['faces'] + v_posed = np.array(v_posed) + faces = np.array(faces) + + minValue = np.min(v_posed) + maxValue = np.max(v_posed) + # print(f"min = {minValue}, max = {maxValue}") + # print("min = {}, max = {}".format(minValue, maxValue)) + + # m = axangle2mat((1, 0, 0), np.radians(180)) + + mesh.InitControlPoints(v_posed.shape[0]) + for i in range(v_posed.shape[0]): + v = v_posed[i, :] + # v = np.matmul(m, v) + vertex = fbx.FbxVector4(v[0], v[1], v[2]) + mesh.SetControlPointAt(vertex, i) + + for i in range(faces.shape[0]): + mesh.BeginPolygon(i) + mesh.AddPolygon(faces[i, 0]) + mesh.AddPolygon(faces[i, 1]) + mesh.AddPolygon(faces[i, 2]) + mesh.EndPolygon() + + if uv_coords is not None and uv_faces is not None: + uv_layer = mesh.CreateElementUV("UVSet") + uv_layer.SetMappingMode(fbx.FbxLayerElement.EMappingMode.eByPolygonVertex) + uv_layer.SetReferenceMode(fbx.FbxLayerElement.EReferenceMode.eIndexToDirect) + + uv_array = uv_layer.GetDirectArray() + for i in range(len(uv_coords)): + uv_array.Add(fbx.FbxVector2(uv_coords[i][0], uv_coords[i][1])) + + uv_index_array = uv_layer.GetIndexArray() + for i in range(len(uv_faces)): + for j in range(3): + uv_index_array.Add(uv_faces[i][j]) + return geometryNode + + +def _addSmplXSkeleton(fbxManager, fbxScene, trans, joint2num, kintree_table): + num2joint = ["" for key in joint2num] + for key, value in joint2num.items(): + num2joint[value] = key + + # trans = np.array(trans) + + # Obtain a reference to the scene's root node. + rootNode = fbxScene.GetRootNode() + + # Create a new node in the scene. + referenceNode = fbx.FbxNode.Create(fbxScene, "Reference") + rootNode.AddChild(referenceNode) + + # Create skeletons + skeletonNodes = [] + for nth in range(len(kintree_table)): + skeleton = fbx.FbxSkeleton.Create(fbxManager, "") + skeleton.SetSkeletonType(fbx.FbxSkeleton.EType.eRoot if nth == -1 else fbx.FbxSkeleton.EType.eLimbNode) + + node = fbx.FbxNode.Create(fbxScene, num2joint[nth]) + node.SetNodeAttribute(skeleton) + + node.LclTranslation.Set(fbx.FbxDouble3(trans[nth, 0], trans[nth, 1], trans[nth, 2])) + + skeletonNodes.append(node) + + if kintree_table[nth] != -1: + skeletonNodes[kintree_table[nth]].AddChild(node) + + referenceNode.AddChild(skeletonNodes[0]) + return referenceNode, skeletonNodes + + +def _addSkiningWeight(fbxScene, lbs_weights, geometryNode, skeletonNodes): + clusters = [] + for i in range(lbs_weights.shape[1]): + cluster = fbx.FbxCluster.Create(fbxScene, "") + cluster.SetLink(skeletonNodes[i]) + cluster.SetLinkMode(fbx.FbxCluster.ELinkMode.eTotalOne) + + for j in range(lbs_weights.shape[0]): + weight = lbs_weights[j, i] + if weight > 0: + cluster.AddControlPointIndex(j, weight) + + clusters.append(cluster) + + # Now we have the Geometry and the skeleton correctly positioned, + # set the transform and TransformLink matrix accordingly. + matrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(geometryNode) + for cluster in clusters: + cluster.SetTransformMatrix(matrix) + + for i in range(len(skeletonNodes)): + matrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(skeletonNodes[i]) + clusters[i].SetTransformLinkMatrix(matrix) + + # Add the clusters to the patch by creating a skin and adding those clusters to that skin. + skin = fbx.FbxSkin.Create(fbxScene, "") + for cluster in clusters: + skin.AddCluster(cluster) + geometryNode.GetNodeAttribute().AddDeformer(skin) + + +def _storeBindPose(fbxScene, geometryNode): + # In the bind pose, we must store all the link's global matrix at the + # time of the bind. + # Plus, we must store all the parent(s) global matrix of a link, even + # if they are not themselves deforming any model. + + clusteredNodes = [] + if geometryNode and geometryNode.GetNodeAttribute(): + skinCount = 0 + clusterCount = 0 + attributeType = geometryNode.GetNodeAttribute().GetAttributeType() + if attributeType in ( + fbx.FbxNodeAttribute.EType.eMesh, + fbx.FbxNodeAttribute.EType.eNurbs, + fbx.FbxNodeAttribute.EType.ePatch, + ): + skinCount = geometryNode.GetNodeAttribute().GetDeformerCount(fbx.FbxDeformer.EDeformerType.eSkin) + for i in range(skinCount): + skin = geometryNode.GetNodeAttribute().GetDeformer(i, fbx.FbxDeformer.EDeformerType.eSkin) + clusterCount += skin.GetClusterCount() + + if clusterCount: + for i in range(skinCount): + skin = geometryNode.GetNodeAttribute().GetDeformer(i, fbx.FbxDeformer.EDeformerType.eSkin) + clusterCount = skin.GetClusterCount() + for j in range(clusterCount): + link = skin.GetCluster(j).GetLink() + _addNodeRecursively(clusteredNodes, link) + + # Add the geometry to the pose + clusteredNodes += [geometryNode] + + # Now create a bind pose with the link list + if len(clusteredNodes): + # A pose must be named. Arbitrarily use the name of the geometry node. + pose = fbx.FbxPose.Create(fbxScene, geometryNode.GetName()) + pose.SetIsBindPose(True) + + for node in clusteredNodes: + bindMatrix = fbxScene.GetAnimationEvaluator().GetNodeGlobalTransform(node) + pose.Add(node, fbx.FbxMatrix(bindMatrix)) + + fbxScene.AddPose(pose) + + +def _addNodeRecursively(nodeArray, node): + """Add the specified node to the node array. + + Also, add recursively all the parent node of the specified node to the array. + """ + if node: + _addNodeRecursively(nodeArray, node.GetParent()) + found = False + if node in nodeArray: + if node.GetName() == node.GetName(): + found = True + if not found: + nodeArray += [node] + + +def _animateGlobalTransformsFromTransMat(animLayer, referenceNode, global_translation, frameDuration): + _animateSingleChannel(animLayer, referenceNode.LclTranslation, "X", global_translation, frameDuration) + _animateSingleChannel(animLayer, referenceNode.LclTranslation, "Y", global_translation, frameDuration) + _animateSingleChannel(animLayer, referenceNode.LclTranslation, "Z", global_translation, frameDuration) + + +def _animateSingleChannel(animLayer, component, name, values, frameDuration): + ncomp = 0 + + if name == "X": + ncomp = 0 + elif name == "Y": + ncomp = 1 + elif name == "Z": + ncomp = 2 + + time = fbx.FbxTime() + curve = component.GetCurve(animLayer, name, True) + curve.KeyModifyBegin() + for nth in range(len(values)): + time.SetSecondDouble(nth * frameDuration) + keyIndex = curve.KeyAdd(time)[0] + curve.KeySetValue(keyIndex, values[nth][ncomp]) + curve.KeySetInterpolation( + keyIndex, fbx.FbxAnimCurveDef.EInterpolationType.eInterpolationConstant + ) # NOTE: using eInterpolationCubic to do interpolation causes error. + curve.KeyModifyEnd() + + +def _animateRotationKeyFrames(animLayer, node, transforms_mat, frameDuration): + rotations = [] + for nth in range(len(transforms_mat)): + rotations.append(np.rad2deg(mat2euler(transforms_mat[nth][0:3, 0:3], axes="sxyz"))) + + _animateSingleChannel(animLayer, node.LclRotation, "X", rotations, frameDuration) + _animateSingleChannel(animLayer, node.LclRotation, "Y", rotations, frameDuration) + _animateSingleChannel(animLayer, node.LclRotation, "Z", rotations, frameDuration) + + +def _animateTranslationKeyFrames(animLayer, node, transforms_mat, frameDuration): + translations = [] + for nth in range(len(transforms_mat)): + translations.append(transforms_mat[nth][0:3, 3]) + + _animateSingleChannel(animLayer, node.LclTranslation, "X", translations, frameDuration) + _animateSingleChannel(animLayer, node.LclTranslation, "Y", translations, frameDuration) + _animateSingleChannel(animLayer, node.LclTranslation, "Z", translations, frameDuration) + + +def _animateScalingKeyFrames(animLayer, node, transforms_mat, frameDuration): + scalings = [] + for nth in range(len(transforms_mat)): + scalings.append( + np.array( + ( + transforms_mat[nth][0, 0], + transforms_mat[nth][1, 1], + transforms_mat[nth][2, 2], + ) + ) + ) + + _animateSingleChannel(animLayer, node.LclTranslation, "X", scalings, frameDuration) + _animateSingleChannel(animLayer, node.LclTranslation, "Y", scalings, frameDuration) + _animateSingleChannel(animLayer, node.LclTranslation, "Z", scalings, frameDuration) + + +def _animateSkeleton(fbxScene, skeletonNodes, frames, frameRate, name="Take1"): + frameDuration = 1.0 / frameRate + + if name != "Take1": + subs = name.split("/") + name = subs[-1][:-5] + + animStack = fbx.FbxAnimStack.Create(fbxScene, name) + animLayer = fbx.FbxAnimLayer.Create(fbxScene, "Base Layer") + animStack.AddMember(animLayer) + _animateGlobalTransformsFromTransMat( + animLayer=animLayer, + referenceNode=skeletonNodes[0], + global_translation=frames[:, 0, :3, 3], + frameDuration=frameDuration, + ) + + for nId in range(len(skeletonNodes)): + _animateRotationKeyFrames( + animLayer=animLayer, + node=skeletonNodes[nId], + transforms_mat=frames[:, nId], + frameDuration=frameDuration, + ) + + +def _saveScene(filename, fbxManager, fbxScene): + exporter = fbx.FbxExporter.Create(fbxManager, "") + isInitialized = exporter.Initialize(filename) + + if isInitialized is False: + raise Exception( + "Exporter failed to initialized. Error returned: {}".format(exporter.GetStatus().GetErrorString()) + ) + + exporter.Export(fbxScene) + exporter.Destroy() + + +def _get_offsets_from_beta(beta, smplx_params, return_template_mesh=True): + v_template = torch.FloatTensor(smplx_params["v_template"]).unsqueeze(0) + shape_dirs = torch.FloatTensor(smplx_params["shapedirs"]) + J_regressor = torch.FloatTensor(smplx_params["J_regressor"]) + + v_shaped = v_template + _blend_shapes(beta, shape_dirs) + J = _vertices2joints(J_regressor, v_shaped).squeeze(0).numpy() + + parents = smplx_params["kintree_table"][()][0] + parents[0] = -1 + Translates = J[()].copy() + Translates[1:] -= J[parents[1:]] + if not return_template_mesh: + return Translates + else: + return Translates, v_shaped + + +def _preprocess_smplx(smplx_params, source_anim_data, scale=1, debug=False): + Translates, v_shaped = _get_offsets_from_beta( + torch.FloatTensor(source_anim_data["betas"]), + smplx_params, + return_template_mesh=True, + ) + + parents = smplx_params["kintree_table"][()][0] + parents[0] = -1 + + poses = torch.FloatTensor(source_anim_data["poses"]) + source_LclRotation = angle_axis_to_rotation_matrix(poses).numpy() + source_LclTranslation = np.tile(Translates, (source_LclRotation.shape[0], 1, 1)) + source_LclTranslation[:, 0] += source_anim_data["trans"] + + source_skeleton = { + "parent": parents, + "LclRotation": source_LclRotation, + "LclTranslation": source_LclTranslation * scale, + "Translate": Translates * scale, + "v_shaped": v_shaped.squeeze(0).numpy() * scale, + } + return source_skeleton + + +def _convert_npz_to_fbx(smplh_params, npz_data, save_fn, fps=30, uv_coords=None, uv_faces=None): + kintree = smplh_params["kintree_table"][0] + kintree[0] = -1 + + source_anim_data = { + "betas": npz_data["betas"], + "poses": npz_data["poses"].reshape(npz_data["poses"].shape[0], -1, 3), + "trans": npz_data["trans"], + } + source_skeleton = _preprocess_smplx(smplh_params, source_anim_data, scale=100) + rot = rot_mat2trans_mat(source_skeleton["LclRotation"]) + trans = trans2trans_mat(source_skeleton["LclTranslation"]) + frame_data = np.einsum("Btnk,Btkm ->Btnm", trans, rot) + + fbxManager = fbx.FbxManager.Create() + fbxScene = fbx.FbxScene.Create(fbxManager, "") + timeMode = fbx.FbxTime().ConvertFrameRateToTimeMode(fps) + fbxScene.GetGlobalSettings().SetTimeMode(timeMode) + + geometryNode = _addSmplXMesh( + fbxScene, + source_skeleton["v_shaped"], + smplh_params["f"], + uv_coords=uv_coords, + uv_faces=uv_faces, + ) + referenceNode, skeletonNodes = _addSmplXSkeleton( + fbxManager, + fbxScene=fbxScene, + trans=source_skeleton["Translate"], + joint2num=SMPLH_JOINT2NUM, + kintree_table=kintree, + ) + + _addSkiningWeight(fbxScene, smplh_params["weights"], geometryNode, skeletonNodes) + _storeBindPose(fbxScene, geometryNode) + _animateSkeleton( + fbxScene=fbxScene, + skeletonNodes=skeletonNodes, + frames=frame_data, + frameRate=fps, + ) + + with tempfile.NamedTemporaryFile(suffix=".fbx", delete=False) as tmp_f: + temp_file = tmp_f.name + + try: + # Save to temporary location + _saveScene(temp_file, fbxManager, fbxScene) + # If successful, copy to final destination + shutil.copy2(temp_file, save_fn) + except Exception as e: + print(f"Error saving FBX file: {e}") + finally: + # Remove temporary file + if os.path.exists(temp_file): + os.remove(temp_file) + + # CLEANUP + fbxManager.Destroy() + del fbxManager, fbxScene + + +def _read_uv(obj_template): + uv_coords = None + uv_faces = None + if obj_template and os.path.isfile(obj_template): + try: + print("Loading UV coordinates from OBJ template: {}".format(obj_template)) + obj_vertices, uv_coords, obj_faces, uv_faces = _parse_obj_file(obj_template) + print("Loaded {} UV coordinates and {} UV faces".format(len(uv_coords), len(uv_faces))) + except Exception as e: + print("Warning: Failed to load UV coordinates from OBJ file: {}".format(e)) + uv_coords = None + uv_faces = None + return uv_coords, uv_faces + + +class SMPLH2FBX: + def __init__( + self, + obj_template="./assets/smpl_family_models/smplh/textures/male_smplh.obj", + smplh_model_path="./assets/body_models/smplh/neutral/model.npz", + ): + print(f"[{self.__class__.__name__}] Load obj_template: {obj_template}") + self.uv_coords, self.uv_faces = _read_uv(obj_template) + print(f"[{self.__class__.__name__}] Load smplh_model_path: {smplh_model_path}") + self.smplh_params = dict(np.load(smplh_model_path, allow_pickle=True)) + + def convert_npz_to_fbx(self, npz_file, outname, fps=30): + os.makedirs(os.path.dirname(outname), exist_ok=True) + if isinstance(npz_file, str) and os.path.isfile(npz_file): + npz_data = dict(np.load(npz_file, allow_pickle=True)) + else: + npz_data = npz_file + _convert_npz_to_fbx( + self.smplh_params, + npz_data, + outname, + uv_coords=self.uv_coords, + uv_faces=self.uv_faces, + ) + return os.path.exists(outname) + + def convert_params_to_fbx(self, params, outname): + fps = params.get("mocap_framerate", 30) + os.makedirs(os.path.dirname(outname), exist_ok=True) + assert len(params["poses"].shape) == 3, f"poses shape should be (F, 52, 3), but got {params['poses'].shape}" + assert len(params["betas"].shape) == 2, f"betas shape should be (1, 16), but got {params['betas'].shape}" + assert len(params["trans"].shape) == 2, f"trans shape should be (1, 3), but got {params['trans'].shape}" + _convert_npz_to_fbx( + self.smplh_params, + params, + outname, + fps=fps, + uv_coords=self.uv_coords, + uv_faces=self.uv_faces, + ) + return os.path.exists(outname) + + +if __name__ == "__main__": + # python hymotion/utils/smplh2fbx.py + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("root", type=str) + args = parser.parse_args() + + converter = SMPLH2FBX() + + if os.path.isdir(args.root): + npzfiles = sorted(glob.glob(os.path.join(args.root, "*.npz"))) + else: + if args.root.endswith(".npz"): + npzfiles = [args.root] + else: + raise ValueError(f"Unknown file type: {args.root}") + + for npzfile in npzfiles: + converter.convert_npz_to_fbx(npzfile, npzfile.replace(".npz", ".fbx").replace("motions", "motions_fbx")) diff --git a/hymotion/utils/smplh2woodfbx.py b/hymotion/utils/smplh2woodfbx.py new file mode 100644 index 0000000000000000000000000000000000000000..303398a9f75a8f72aff684bef5bcaf3a2208cc66 --- /dev/null +++ b/hymotion/utils/smplh2woodfbx.py @@ -0,0 +1,702 @@ +import glob +import os +import shutil +import tempfile +from typing import Dict, Optional + +import fbx +import numpy as np +import torch +from transforms3d.euler import mat2euler + +from .geometry import angle_axis_to_rotation_matrix + +# yapf: disable +SMPLH_JOINT2NUM = { + "Pelvis": 0, "L_Hip": 1, "R_Hip": 2, "Spine1": 3, + "L_Knee": 4, "R_Knee": 5, "Spine2": 6, + "L_Ankle": 7, "R_Ankle": 8, + "Spine3": 9, + "L_Foot": 10, "R_Foot": 11, + "Neck": 12, "L_Collar": 13, "R_Collar": 14, "Head": 15, + "L_Shoulder": 16, "R_Shoulder": 17, + "L_Elbow": 18, "R_Elbow": 19, + "L_Wrist": 20, "R_Wrist": 21, + "L_Index1": 22, "L_Index2": 23, "L_Index3": 24, + "L_Middle1": 25, "L_Middle2": 26, "L_Middle3": 27, + "L_Pinky1": 28, "L_Pinky2": 29, "L_Pinky3": 30, + "L_Ring1": 31, "L_Ring2": 32, "L_Ring3": 33, + "L_Thumb1": 34, "L_Thumb2": 35, "L_Thumb3": 36, + "R_Index1": 37, "R_Index2": 38, "R_Index3": 39, + "R_Middle1": 40, "R_Middle2": 41, "R_Middle3": 42, + "R_Pinky1": 43, "R_Pinky2": 44, "R_Pinky3": 45, + "R_Ring1": 46, "R_Ring2": 47, "R_Ring3": 48, + "R_Thumb1": 49, "R_Thumb2": 50, "R_Thumb3": 51, +} + +# Mapping from SMPL-H joint names to lowercase names used in some FBX templates +SMPLH_TO_LOWERCASE_MAPPING = { + "Pelvis": "pelvis", + "L_Hip": "left_hip", + "R_Hip": "right_hip", + "Spine1": "spine1", + "L_Knee": "left_knee", + "R_Knee": "right_knee", + "Spine2": "spine2", + "L_Ankle": "left_ankle", + "R_Ankle": "right_ankle", + "Spine3": "spine3", + "L_Foot": "left_foot", + "R_Foot": "right_foot", + "Neck": "neck", + "L_Collar": "left_collar", + "R_Collar": "right_collar", + "Head": "head", + "L_Shoulder": "left_shoulder", + "R_Shoulder": "right_shoulder", + "L_Elbow": "left_elbow", + "R_Elbow": "right_elbow", + "L_Wrist": "left_wrist", + "R_Wrist": "right_wrist", + "L_Index1": "left_index1", + "L_Index2": "left_index2", + "L_Index3": "left_index3", + "L_Middle1": "left_middle1", + "L_Middle2": "left_middle2", + "L_Middle3": "left_middle3", + "L_Pinky1": "left_pinky1", + "L_Pinky2": "left_pinky2", + "L_Pinky3": "left_pinky3", + "L_Ring1": "left_ring1", + "L_Ring2": "left_ring2", + "L_Ring3": "left_ring3", + "L_Thumb1": "left_thumb1", + "L_Thumb2": "left_thumb2", + "L_Thumb3": "left_thumb3", + "R_Index1": "right_index1", + "R_Index2": "right_index2", + "R_Index3": "right_index3", + "R_Middle1": "right_middle1", + "R_Middle2": "right_middle2", + "R_Middle3": "right_middle3", + "R_Pinky1": "right_pinky1", + "R_Pinky2": "right_pinky2", + "R_Pinky3": "right_pinky3", + "R_Ring1": "right_ring1", + "R_Ring2": "right_ring2", + "R_Ring3": "right_ring3", + "R_Thumb1": "right_thumb1", + "R_Thumb2": "right_thumb2", + "R_Thumb3": "right_thumb3", +} +# yapf: enable + + +def _loadFbxScene(fbxManager, filepath): + """Load an FBX file into a scene""" + importer = fbx.FbxImporter.Create(fbxManager, "") + + if not importer.Initialize(filepath, -1, fbxManager.GetIOSettings()): + raise Exception( + f"Failed to initialize FBX importer for: {filepath}\nError: {importer.GetStatus().GetErrorString()}" + ) + + fbxScene = fbx.FbxScene.Create(fbxManager, "") + importer.Import(fbxScene) + importer.Destroy() + + return fbxScene + + +def _collectAllNodes(node, nodes_dict=None): + """Recursively collect all nodes in the scene hierarchy""" + if nodes_dict is None: + nodes_dict = {} + + nodes_dict[node.GetName()] = node + + for i in range(node.GetChildCount()): + _collectAllNodes(node.GetChild(i), nodes_dict) + + return nodes_dict + + +def _collectSkeletonNodes(node, skeleton_nodes=None): + """Recursively collect skeleton/bone nodes""" + if skeleton_nodes is None: + skeleton_nodes = {} + + # Check if this node has a skeleton attribute + attr = node.GetNodeAttribute() + if attr and attr.GetAttributeType() == fbx.FbxNodeAttribute.EType.eSkeleton: + skeleton_nodes[node.GetName()] = node + + for i in range(node.GetChildCount()): + _collectSkeletonNodes(node.GetChild(i), skeleton_nodes) + + return skeleton_nodes + + +def _animateSingleChannel(animLayer, component, name, values, frameDuration): + """Animate a single channel (X, Y, or Z) with keyframes""" + ncomp = {"X": 0, "Y": 1, "Z": 2}.get(name, 0) + + time = fbx.FbxTime() + curve = component.GetCurve(animLayer, name, True) + curve.KeyModifyBegin() + for nth in range(len(values)): + time.SetSecondDouble(nth * frameDuration) + keyIndex = curve.KeyAdd(time)[0] + curve.KeySetValue(keyIndex, values[nth][ncomp]) + curve.KeySetInterpolation(keyIndex, fbx.FbxAnimCurveDef.EInterpolationType.eInterpolationConstant) + curve.KeyModifyEnd() + + +def _animateRotationKeyFrames(animLayer, node, rot_matrices, frameDuration): + """Animate rotation keyframes for a node using rotation matrices""" + rotations = [] + for nth in range(len(rot_matrices)): + # Convert rotation matrix to Euler angles (XYZ order) + euler = np.rad2deg(mat2euler(rot_matrices[nth], axes="sxyz")) + rotations.append(euler) + + _animateSingleChannel(animLayer, node.LclRotation, "X", rotations, frameDuration) + _animateSingleChannel(animLayer, node.LclRotation, "Y", rotations, frameDuration) + _animateSingleChannel(animLayer, node.LclRotation, "Z", rotations, frameDuration) + + +def _animateTranslationKeyFrames(animLayer, node, translations, frameDuration): + """Animate translation keyframes for a node""" + # Ensure translations is a numpy array with shape (num_frames, 3) + if isinstance(translations, torch.Tensor): + translations = translations.numpy() + translations = np.asarray(translations, dtype=np.float64) + + if len(translations.shape) == 1: + # Single frame, reshape to (1, 3) + translations = translations.reshape(1, -1) + + _animateSingleChannel(animLayer, node.LclTranslation, "X", translations, frameDuration) + _animateSingleChannel(animLayer, node.LclTranslation, "Y", translations, frameDuration) + _animateSingleChannel(animLayer, node.LclTranslation, "Z", translations, frameDuration) + + +def _clearExistingAnimations(fbxScene): + """Remove all existing animation stacks from the scene""" + anim_stack_count = fbxScene.GetSrcObjectCount(fbx.FbxCriteria.ObjectType(fbx.FbxAnimStack.ClassId)) + for i in range(anim_stack_count - 1, -1, -1): + anim_stack = fbxScene.GetSrcObject(fbx.FbxCriteria.ObjectType(fbx.FbxAnimStack.ClassId), i) + if anim_stack: + anim_stack.Destroy() + + +def _applyAnimationToSkeleton(fbxScene, nodes_map, rot_matrices, translations, fps, smplh_to_fbx_mapping, name="Take1"): + """ + Apply SMPL-H animation data to skeleton nodes in the FBX scene. + + Args: + fbxScene: FBX scene object + nodes_map: Dictionary of node_name -> FbxNode + rot_matrices: (num_frames, num_joints, 3, 3) rotation matrices + translations: (num_frames, 3) root translations (relative displacement, not absolute position) + fps: Frame rate + smplh_to_fbx_mapping: Mapping from SMPL-H joint names to FBX node names + name: Animation take name + """ + frameDuration = 1.0 / fps + num_frames = rot_matrices.shape[0] + num_joints = rot_matrices.shape[1] + + # Create animation stack and layer + animStack = fbx.FbxAnimStack.Create(fbxScene, name) + animLayer = fbx.FbxAnimLayer.Create(fbxScene, "Base Layer") + animStack.AddMember(animLayer) + + # Track if root translation was applied + root_translation_applied = False + root_node = None + + # Get root node's initial LclTranslation from template (this is like Translates[0] in smplh2woodfbx.py) + root_initial_translation = None + root_fbx_name = smplh_to_fbx_mapping.get("Pelvis") + if root_fbx_name and root_fbx_name in nodes_map: + root_node_temp = nodes_map[root_fbx_name] + initial_trans = root_node_temp.LclTranslation.Get() + root_initial_translation = np.array([initial_trans[0], initial_trans[1], initial_trans[2]]) + print(f"Root initial LclTranslation from template: {root_initial_translation}") + + # Animate each joint + for smplh_joint_name, smplh_joint_idx in SMPLH_JOINT2NUM.items(): + if smplh_joint_idx >= num_joints: + continue + + # Get the FBX node name from mapping + fbx_node_name = smplh_to_fbx_mapping.get(smplh_joint_name) + if not fbx_node_name: + if smplh_joint_idx == 0: + print(f"Warning: Root joint 'Pelvis' not found in mapping!") + continue + + # Find the node + node = nodes_map.get(fbx_node_name) + if not node: + print(f"Warning: Joint '{smplh_joint_name}' (FBX: '{fbx_node_name}') not found in scene") + continue + + # Animate rotation + _animateRotationKeyFrames( + animLayer=animLayer, + node=node, + rot_matrices=rot_matrices[:, smplh_joint_idx], + frameDuration=frameDuration, + ) + + # Animate translation for root joint (Pelvis) + if smplh_joint_idx == 0: + root_node = node + # Add initial offset to translations (like smplh2woodfbx.py does: Translates[0] + trans) + # The translations input is relative displacement, we need to add the template's initial position + if root_initial_translation is not None: + final_translations = translations + root_initial_translation + print( + f"Applying root translation to '{fbx_node_name}', frames={num_frames}, " + f"initial_offset={root_initial_translation}, " + f"final translation range: {final_translations.min(axis=0)} to {final_translations.max(axis=0)}" + ) + else: + final_translations = translations + print( + f"Applying root translation to '{fbx_node_name}', frames={num_frames}, " + f"translation range: {final_translations.min(axis=0)} to {final_translations.max(axis=0)}" + ) + _animateTranslationKeyFrames( + animLayer=animLayer, + node=node, + translations=final_translations, + frameDuration=frameDuration, + ) + root_translation_applied = True + + # If root translation was not applied, try to find root node by common names + if not root_translation_applied: + print("Warning: Root translation was not applied through normal mapping, trying fallback...") + root_candidates = ["Pelvis", "pelvis", "Hips", "hips", "Root", "root", "mixamorig:Hips"] + for candidate in root_candidates: + if candidate in nodes_map: + root_node = nodes_map[candidate] + # Get initial translation for fallback node + initial_trans = root_node.LclTranslation.Get() + fallback_initial = np.array([initial_trans[0], initial_trans[1], initial_trans[2]]) + final_translations = translations + fallback_initial + print( + f"Found root node by fallback: '{candidate}', initial_offset={fallback_initial}, applying translation..." + ) + _animateTranslationKeyFrames( + animLayer=animLayer, + node=root_node, + translations=final_translations, + frameDuration=frameDuration, + ) + root_translation_applied = True + break + + if not root_translation_applied: + print("ERROR: Could not find root node to apply translation!") + print(f"Available nodes: {list(nodes_map.keys())}") + + return animStack + + +def _saveScene(filename, fbxManager, fbxScene, embed_textures=True): + """Save the FBX scene to a file + + Args: + filename: Output file path + fbxManager: FBX manager instance + fbxScene: FBX scene to save + embed_textures: Whether to embed textures/media in the FBX file (default True) + """ + # Configure IOSettings to embed textures/media + ios = fbxManager.GetIOSettings() + if embed_textures: + ios.SetBoolProp(fbx.EXP_FBX_EMBEDDED, True) + ios.SetBoolProp(fbx.EXP_FBX_MATERIAL, True) + ios.SetBoolProp(fbx.EXP_FBX_TEXTURE, True) + + exporter = fbx.FbxExporter.Create(fbxManager, "") + isInitialized = exporter.Initialize(filename, -1, ios) + + if isInitialized is False: + raise Exception(f"Exporter failed to initialize. Error: {exporter.GetStatus().GetErrorString()}") + + exporter.Export(fbxScene) + exporter.Destroy() + + +def _convert_smplh_to_woodfbx( + template_fbx_path, + npz_data, + save_fn, + fps=30, + scale=100, + smplh_to_fbx_mapping=None, + clear_animations=True, +): + """ + Convert SMPL-H parameters to FBX using a template FBX file. + The template FBX skeleton is already consistent with SMPL-H, so we directly copy parameters. + + Args: + template_fbx_path: Path to the template FBX file (e.g., boy_Rigging_smplx.fbx) + npz_data: Dictionary containing SMPL-H parameters + - poses: (num_frames, 52, 3) or (num_frames, 156) + - trans: (num_frames, 3) + save_fn: Output FBX file path + fps: Frame rate + scale: Scale factor for translation (default 100 for m to cm conversion) + smplh_to_fbx_mapping: Custom mapping from SMPL-H joint names to FBX node names + clear_animations: Whether to clear existing animations in the template + + Returns: + bool: True if successful + """ + # Prepare poses data + poses = npz_data["poses"] + if isinstance(poses, np.ndarray): + poses = torch.from_numpy(poses).float() + + if len(poses.shape) == 2: + # (num_frames, 156) -> (num_frames, 52, 3) + poses = poses.reshape(poses.shape[0], -1, 3) + + # Convert axis-angle to rotation matrices: (num_frames, num_joints, 3, 3) + rot_matrices = angle_axis_to_rotation_matrix(poses).numpy() + + # Prepare translation data + trans = npz_data["trans"] + if isinstance(trans, torch.Tensor): + trans = trans.numpy() + + # Apply scale to translation + translations = trans * scale + + # Create FBX manager and load template + fbxManager = fbx.FbxManager.Create() + ios = fbx.FbxIOSettings.Create(fbxManager, fbx.IOSROOT) + fbxManager.SetIOSettings(ios) + + print(f"Loading FBX template: {template_fbx_path}") + fbxScene = _loadFbxScene(fbxManager, template_fbx_path) + + # Set time mode + timeMode = fbx.FbxTime().ConvertFrameRateToTimeMode(fps) + fbxScene.GetGlobalSettings().SetTimeMode(timeMode) + + # Collect all nodes + rootNode = fbxScene.GetRootNode() + all_nodes = _collectAllNodes(rootNode) + skeleton_nodes = _collectSkeletonNodes(rootNode) + + print(f"Found {len(all_nodes)} nodes in scene") + print(f"Found {len(skeleton_nodes)} skeleton nodes: {list(skeleton_nodes.keys())}") + + # Use default mapping if not provided + if smplh_to_fbx_mapping is None: + smplh_to_fbx_mapping = _auto_detect_mapping(all_nodes) + print(f"Auto-detected {len(smplh_to_fbx_mapping)} joint mappings") + if "Pelvis" in smplh_to_fbx_mapping: + print(f" Root joint 'Pelvis' mapped to: '{smplh_to_fbx_mapping['Pelvis']}'") + else: + print(f" WARNING: Root joint 'Pelvis' not found in mapping!") + print(f" Available nodes: {list(all_nodes.keys())[:20]}...") # Show first 20 nodes + + # Clear existing animations if requested + if clear_animations: + _clearExistingAnimations(fbxScene) + + # Apply animation to skeleton + _applyAnimationToSkeleton( + fbxScene=fbxScene, + nodes_map=all_nodes, + rot_matrices=rot_matrices, + translations=translations, + fps=fps, + smplh_to_fbx_mapping=smplh_to_fbx_mapping, + name="SMPLH_Animation", + ) + + # Save to temporary file first, then copy to final destination + os.makedirs(os.path.dirname(save_fn) if os.path.dirname(save_fn) else ".", exist_ok=True) + with tempfile.NamedTemporaryFile(suffix=".fbx", delete=False) as tmp_f: + temp_file = tmp_f.name + + try: + _saveScene(temp_file, fbxManager, fbxScene) + shutil.copy2(temp_file, save_fn) + os.remove(temp_file) + print(f"Successfully saved FBX to: {save_fn}") + except Exception as e: + print(f"Error saving FBX file: {e}") + return False + finally: + fbxManager.Destroy() + del fbxManager, fbxScene + + return os.path.exists(save_fn) + + +def _auto_detect_mapping(all_nodes): + """Auto-detect the mapping from SMPL-H joints to FBX nodes""" + mapping = {} + for smplh_name in SMPLH_JOINT2NUM.keys(): + # Try exact match + if smplh_name in all_nodes: + mapping[smplh_name] = smplh_name + # Try lowercase version + elif SMPLH_TO_LOWERCASE_MAPPING.get(smplh_name) in all_nodes: + mapping[smplh_name] = SMPLH_TO_LOWERCASE_MAPPING[smplh_name] + return mapping + + +class SMPLH2WoodFBX: + """ + Class to convert SMPL-H parameters to FBX using a template FBX file. + The template FBX skeleton is already consistent with SMPL-H, so we directly copy parameters. + No SMPL-H model assets (model.npz) required. + + Example usage: + converter = SMPLH2WoodFBX( + template_fbx_path="./assets/wooden_models/boy_Rigging_smplx.fbx" + ) + + # From npz file + converter.convert_npz_to_fbx("motion.npz", "output.fbx", fps=30) + + # From parameters dict + params = { + "poses": poses_array, # (num_frames, 52, 3) or (num_frames, 156) + "trans": trans_array, # (num_frames, 3) + } + converter.convert_params_to_fbx(params, "output.fbx") + """ + + def __init__( + self, + template_fbx_path: str = "./assets/wooden_models/boy_Rigging_smplx_tex.fbx", + smplh_to_fbx_mapping: Optional[Dict[str, str]] = None, + scale: float = 100, + ): + """ + Initialize the converter. + + Args: + template_fbx_path: Path to the template FBX file + smplh_to_fbx_mapping: Custom mapping from SMPL-H joint names to FBX node names + scale: Scale factor for translation (default 100 for m to cm conversion) + """ + print(f"[{self.__class__.__name__}] Template FBX: {template_fbx_path}") + self.template_fbx_path = template_fbx_path + self.smplh_to_fbx_mapping = smplh_to_fbx_mapping + self.scale = scale + + # Analyze template FBX to detect joint names + self._analyze_template() + + def _analyze_template(self): + """Analyze the template FBX file to detect available skeleton nodes""" + fbxManager = fbx.FbxManager.Create() + ios = fbx.FbxIOSettings.Create(fbxManager, fbx.IOSROOT) + fbxManager.SetIOSettings(ios) + + try: + fbxScene = _loadFbxScene(fbxManager, self.template_fbx_path) + rootNode = fbxScene.GetRootNode() + + self.all_template_nodes = list(_collectAllNodes(rootNode).keys()) + self.skeleton_template_nodes = list(_collectSkeletonNodes(rootNode).keys()) + + print(f"[{self.__class__.__name__}] Template nodes: {len(self.all_template_nodes)}") + print(f"[{self.__class__.__name__}] Skeleton nodes: {self.skeleton_template_nodes}") + + # Auto-detect mapping if not provided + if self.smplh_to_fbx_mapping is None: + self.smplh_to_fbx_mapping = self._auto_detect_mapping() + print(f"[{self.__class__.__name__}] Auto-detected {len(self.smplh_to_fbx_mapping)} joint mappings") + finally: + fbxManager.Destroy() + + def _auto_detect_mapping(self): + """Auto-detect the mapping from SMPL-H joints to FBX nodes""" + mapping = {} + for smplh_name in SMPLH_JOINT2NUM.keys(): + # Try exact match + if smplh_name in self.all_template_nodes: + mapping[smplh_name] = smplh_name + # Try lowercase version + elif SMPLH_TO_LOWERCASE_MAPPING.get(smplh_name) in self.all_template_nodes: + mapping[smplh_name] = SMPLH_TO_LOWERCASE_MAPPING[smplh_name] + return mapping + + def convert_npz_to_fbx(self, npz_file, outname, fps=30, clear_animations=True): + """ + Convert an npz file containing SMPL-H parameters to FBX. + + Args: + npz_file: Path to the npz file or dict containing SMPL-H parameters + outname: Output FBX file path + fps: Frame rate + clear_animations: Whether to clear existing animations in template + + Returns: + bool: True if successful + """ + os.makedirs(os.path.dirname(outname) if os.path.dirname(outname) else ".", exist_ok=True) + + if isinstance(npz_file, str) and os.path.isfile(npz_file): + npz_data = dict(np.load(npz_file, allow_pickle=True)) + else: + npz_data = npz_file + + return _convert_smplh_to_woodfbx( + template_fbx_path=self.template_fbx_path, + npz_data=npz_data, + save_fn=outname, + fps=fps, + scale=self.scale, + smplh_to_fbx_mapping=self.smplh_to_fbx_mapping, + clear_animations=clear_animations, + ) + + def convert_params_to_fbx(self, params, outname, clear_animations=True): + """ + Convert SMPL-H parameters to FBX. + + Args: + params: Dictionary containing SMPL-H parameters + - poses: (num_frames, 52, 3) or (num_frames, 156) + - trans: (num_frames, 3) + - mocap_framerate (optional): Frame rate + outname: Output FBX file path + clear_animations: Whether to clear existing animations in template + + Returns: + bool: True if successful + """ + fps = params.get("mocap_framerate", 30) + os.makedirs(os.path.dirname(outname) if os.path.dirname(outname) else ".", exist_ok=True) + + npz_data = { + "poses": params["poses"], + "trans": params["trans"], + } + + return _convert_smplh_to_woodfbx( + template_fbx_path=self.template_fbx_path, + npz_data=npz_data, + save_fn=outname, + fps=fps, + scale=self.scale, + smplh_to_fbx_mapping=self.smplh_to_fbx_mapping, + clear_animations=clear_animations, + ) + + +LEFT_HAND_MEAN_AA = [ 0.1117, 0.0429, -0.4164, 0.1088, -0.0660, -0.7562, -0.0964, -0.0909, + -0.1885, -0.1181, 0.0509, -0.5296, -0.1437, 0.0552, -0.7049, -0.0192, + -0.0923, -0.3379, -0.4570, -0.1963, -0.6255, -0.2147, -0.0660, -0.5069, + -0.3697, -0.0603, -0.0795, -0.1419, -0.0859, -0.6355, -0.3033, -0.0579, + -0.6314, -0.1761, -0.1321, -0.3734, 0.8510, 0.2769, -0.0915, -0.4998, + 0.0266, 0.0529, 0.5356, 0.0460, -0.2774] +RIGHT_HAND_MEAN_AA = [ 0.1117, -0.0429, 0.4164, 0.1088, 0.0660, 0.7562, -0.0964, 0.0909, + 0.1885, -0.1181, -0.0509, 0.5296, -0.1437, -0.0552, 0.7049, -0.0192, + 0.0923, 0.3379, -0.4570, 0.1963, 0.6255, -0.2147, 0.0660, 0.5069, + -0.3697, 0.0603, 0.0795, -0.1419, 0.0859, 0.6355, -0.3033, 0.0579, + 0.6314, -0.1761, 0.1321, 0.3734, 0.8510, -0.2769, 0.0915, -0.4998, + -0.0266, -0.0529, 0.5356, -0.0460, 0.2774] + +def construct_smpl_data_dict( + rot6d, + transl, + betas=None, + gender="neutral", + use_default_hand_mean_pose=False, +) -> dict: + rotation_matrix = rot6d_to_rotation_matrix(rot6d) + angle_axis = rotation_matrix_to_angle_axis(rotation_matrix) + left_hand_mean_pose = ( + torch.tensor( + LEFT_HAND_MEAN_AA, + device=angle_axis.device, + dtype=angle_axis.dtype, + ) + .unsqueeze(0) + .repeat(angle_axis.shape[0], 1) + .reshape(angle_axis.shape[0], -1, 3) + ) + right_hand_mean_pose = ( + torch.tensor( + RIGHT_HAND_MEAN_AA, + device=angle_axis.device, + dtype=angle_axis.dtype, + ) + .unsqueeze(0) + .repeat(angle_axis.shape[0], 1) + .reshape(angle_axis.shape[0], -1, 3) + ) + if angle_axis.shape[1] == 22: + angle_axis = torch.cat( + [ + angle_axis, + left_hand_mean_pose, + right_hand_mean_pose, + ], + dim=1, + ) + elif angle_axis.shape[1] == 52: + if use_default_hand_mean_pose: + angle_axis = torch.cat( + [ + angle_axis[:, :22], + left_hand_mean_pose, + right_hand_mean_pose, + ], + dim=1, + ) + else: + angle_axis = angle_axis + + assert angle_axis.shape[1] == 52, f"angle_axis should be 52, but got {angle_axis.shape[1]}" + dump = { + "betas": betas.cpu().numpy() if betas is not None else np.zeros((1, 16)), + "gender": gender, + "poses": angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1), + "trans": transl.cpu().numpy(), + "mocap_framerate": 30, + "num_frames": angle_axis.shape[0], + "Rh": angle_axis.cpu().numpy().reshape(angle_axis.shape[0], -1)[:, :3], + } + return dump + +if __name__ == "__main__": + # python hymotion/utils/smplh2woodfbx.py + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("root", type=str) + args = parser.parse_args() + + converter = SMPLH2WoodFBX( + template_fbx_path="./assets/wooden_models/boy_Rigging_smplx_tex.fbx", + scale=100, + ) + + if os.path.isdir(args.root): + npzfiles = sorted(glob.glob(os.path.join(args.root, "*.npz"))) + else: + if args.root.endswith(".npz"): + npzfiles = [args.root] + else: + raise ValueError(f"Unknown file type: {args.root}") + + for npzfile in npzfiles: + converter.convert_npz_to_fbx(npzfile, npzfile.replace(".npz", ".fbx").replace("motions", "motions_fbx")) diff --git a/hymotion/utils/t2m_runtime.py b/hymotion/utils/t2m_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..1a94cfde452b0da1b77a3668db0fe39cd94985a7 --- /dev/null +++ b/hymotion/utils/t2m_runtime.py @@ -0,0 +1,378 @@ +# t2m_runtime.py +import os +import threading +import time +import uuid +from typing import List, Optional, Tuple, Union + +import torch +import yaml + +from ..prompt_engineering.prompt_rewrite import PromptRewriter +from .loaders import load_object +from .visualize_mesh_web import save_visualization_data + +try: + import fbx + + FBX_AVAILABLE = True + print(">>> FBX module found.") +except ImportError: + FBX_AVAILABLE = False + print(">>> FBX module not found.") + + +def _get_local_ip(): + import subprocess + + result = subprocess.run(["hostname", "-I"], capture_output=True, text=True, timeout=5) + if result.returncode == 0: + for ip in result.stdout.strip().split(): + if not ip.startswith("127.") and not ip.startswith("172.17."): + return ip + return "localhost" + + +def _now(): + t = time.time() + ms = int((t - int(t)) * 1000) + return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}" + + +class T2MRuntime: + def __init__( + self, + config_path: str, + ckpt_name: str = "latest.ckpt", + skip_text: bool = False, + device_ids: Union[list[int], None] = None, + prompt_engineering_host: Optional[str] = None, + skip_model_loading: bool = False, + ): + self.config_path = config_path + self.ckpt_name = ckpt_name + self.skip_text = skip_text + self.prompt_engineering_host = prompt_engineering_host + self.skip_model_loading = skip_model_loading + self.local_ip = _get_local_ip() + + # Check for CPU-only mode via environment variable + # Set HY_MOTION_DEVICE=cpu to force CPU mode + force_cpu = os.environ.get("HY_MOTION_DEVICE", "").lower() == "cpu" + if force_cpu: + print(">>> [INFO] CPU mode enabled via HY_MOTION_DEVICE=cpu environment variable") + self.device_ids = [] + elif torch.cuda.is_available(): + all_ids = list(range(torch.cuda.device_count())) + self.device_ids = all_ids if device_ids is None else [i for i in device_ids if i in all_ids] + else: + self.device_ids = [] + + self.pipelines = [] + self._gpu_load = [] + self._lock = threading.Lock() + self._loaded = False + + self.prompt_rewriter = PromptRewriter(backend="our_rewriter", host=self.prompt_engineering_host) + + # Skip model loading if checkpoint not found + if self.skip_model_loading: + print(">>> [WARNING] Model loading skipped - checkpoint not found") + self._loaded = True # Mark as loaded to prevent further load attempts + else: + self.load() + self.fbx_available = FBX_AVAILABLE + if self.fbx_available: + try: + from .smplh2woodfbx import SMPLH2WoodFBX + + self.fbx_converter = SMPLH2WoodFBX() + except Exception as e: + print(f">>> Failed to initialize FBX converter: {e}") + self.fbx_available = False + self.fbx_converter = None + else: + self.fbx_converter = None + print(">>> FBX module not found. FBX export will be disabled.") + + device_info = self.device_ids if self.device_ids else 'cpu' + if self.skip_model_loading: + print(f">>> T2MRuntime initialized (model NOT loaded) in IP {self.local_ip}, devices={device_info}") + else: + print(f">>> T2MRuntime loaded in IP {self.local_ip}, devices={device_info}") + + def load(self): + if self._loaded: + return + print(f">>> Loading model from {self.config_path}...") + + with open(self.config_path, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + if not self.device_ids: + pipeline = load_object( + config["train_pipeline"], + config["train_pipeline_args"], + network_module=config["network_module"], + network_module_args=config["network_module_args"], + ) + device = torch.device("cpu") + pipeline.load_in_demo( + self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text + ) + pipeline.to(device) + self.pipelines = [pipeline] + self._gpu_load = [0] + else: + for gid in self.device_ids: + p = load_object( + config["train_pipeline"], + config["train_pipeline_args"], + network_module=config["network_module"], + network_module_args=config["network_module_args"], + ) + p.load_in_demo(self.ckpt_name, os.path.dirname(self.ckpt_name), build_text_encoder=not self.skip_text) + p.to(torch.device(f"cuda:{gid}")) + self.pipelines.append(p) + self._gpu_load = [0] * len(self.pipelines) + + self._loaded = True + + def _acquire_pipeline(self) -> int: + while True: + with self._lock: + for i in range(len(self._gpu_load)): + if self._gpu_load[i] == 0: + self._gpu_load[i] = 1 + return i + time.sleep(0.01) + + def _release_pipeline(self, idx: int): + with self._lock: + self._gpu_load[idx] = 0 + + def test_dit_inference(self, duration: float = 2.0, seed: int = 42) -> bool: + """ + Test DiT model inference with unconditional/blank input. + This method is used to verify the DiT model works before loading text encoder. + + Args: + duration: Duration of the test motion in seconds + seed: Random seed for reproducibility + + Returns: + True if inference succeeds and produces valid output + """ + if not self.pipelines: + raise RuntimeError("No pipeline loaded. Call load() first.") + + pi = self._acquire_pipeline() + try: + pipeline = self.pipelines[pi] + pipeline.eval() + device = next(pipeline.parameters()).device + + # Calculate frame length from duration (assuming 30fps output, 20fps internal) + length = int(duration * 20) + length = min(length, pipeline.train_frames) + + # Use null features for unconditional generation + batch_size = 1 + vtxt_input = pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device) + ctxt_input = pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device) + ctxt_length = torch.tensor([1] * batch_size, device=device) + + # Create masks + from ..pipeline.motion_diffusion import length_to_mask + + ctxt_mask_temporal = length_to_mask(ctxt_length, ctxt_input.shape[1]) + x_length = torch.LongTensor([length] * batch_size).to(device) + x_mask_temporal = length_to_mask(x_length, pipeline.train_frames) + + # Run denoising inference + print(f"\t>>> Running DiT inference test: length={length}, device={device}") + + # Create random noise + generator = torch.Generator(device=device).manual_seed(seed) + latent_shape = (batch_size, pipeline.train_frames, pipeline.mean.shape[-1]) + latents = torch.randn(latent_shape, generator=generator, device=device, dtype=vtxt_input.dtype) + + # Simple single-step denoising test (just forward pass) + with torch.no_grad(): + # Get timestep + timesteps = torch.tensor([0.5], device=device, dtype=vtxt_input.dtype).expand(batch_size) + + # Forward pass through DiT + # Use correct parameter names for HunyuanMotionMMDiT.forward() + _ = pipeline.motion_transformer( + x=latents, + ctxt_input=ctxt_input, + vtxt_input=vtxt_input, + timesteps=timesteps, + x_mask_temporal=x_mask_temporal, + ctxt_mask_temporal=ctxt_mask_temporal, + ) + + print(f"\t>>> DiT forward pass completed successfully!") + return True + + except Exception as e: + print(f"\t>>> DiT inference test failed: {e}") + raise + finally: + self._release_pipeline(pi) + + def load_text_encoder(self) -> None: + """ + Load text encoder for all pipelines. + This is called after DiT model testing to complete the initialization. + """ + if not self.pipelines: + raise RuntimeError("No pipeline loaded. Call load() first.") + + print(">>> Loading text encoder for all pipelines...") + for i, pipeline in enumerate(self.pipelines): + if not hasattr(pipeline, "text_encoder") or pipeline.text_encoder is None: + device = next(pipeline.parameters()).device + pipeline.text_encoder = load_object(pipeline._text_encoder_module, pipeline._text_encoder_cfg) + pipeline.text_encoder.to(device) + print(f"\t>>> Text encoder loaded for pipeline {i} on {device}") + + # Update skip_text flag + self.skip_text = False + print(">>> Text encoder loading completed!") + + def rewrite_text_and_infer_time(self, text: str) -> Tuple[float, str]: + print("Start rewriting text...") + duration, rewritten_text = self.prompt_rewriter.rewrite_prompt_and_infer_time(f"{text}") + print(f"\t>>> Rewritten text: {rewritten_text}, duration: {duration:.2f} seconds") + return duration, rewritten_text + + def generate_motion( + self, + text: str, + seeds_csv: str, + duration: float, + cfg_scale: float, + output_format: str = "fbx", + output_dir: Optional[str] = None, + output_filename: Optional[str] = None, + original_text: Optional[str] = None, + use_special_game_feat: bool = False, + ) -> Tuple[Union[str, list[str]], dict]: + # Check if model was skipped due to missing checkpoint + if self.skip_model_loading: + raise RuntimeError( + "Motion generation is not available: model checkpoint was not found. " + "Please ensure the checkpoint file exists at the specified path." + ) + + self.load() + seeds = [int(s.strip()) for s in seeds_csv.split(",") if s.strip() != ""] + pi = self._acquire_pipeline() + try: + pipeline = self.pipelines[pi] + pipeline.eval() + + # When skip_text=True (debug mode), use blank text features + if self.skip_text: + print(">>> [Debug Mode] Using blank text features (skip_text=True)") + device = next(pipeline.parameters()).device + batch_size = len(seeds) if seeds else 1 + # Create blank hidden_state_dict using null features + hidden_state_dict = { + "text_vec_raw": pipeline.null_vtxt_feat.expand(batch_size, -1, -1).to(device), + "text_ctxt_raw": pipeline.null_ctxt_input.expand(batch_size, -1, -1).to(device), + "text_ctxt_raw_length": torch.tensor([1] * batch_size, device=device), + } + # Disable CFG in debug mode (use cfg_scale=1.0) + model_output = pipeline.generate( + text, + seeds, + duration, + cfg_scale=1.0, + use_special_game_feat=False, + hidden_state_dict=hidden_state_dict, + ) + else: + model_output = pipeline.generate( + text, seeds, duration, cfg_scale=cfg_scale, use_special_game_feat=use_special_game_feat + ) + finally: + self._release_pipeline(pi) + + ts = _now() + save_data, base_filename = save_visualization_data( + output=model_output, + text=text if original_text is None else original_text, + rewritten_text=text, + timestamp=ts, + output_dir=output_dir, + output_filename=output_filename, + ) + + view_url = self._generate_html_view_url( + timestamp=ts, + file_path=base_filename, + output_dir=output_dir, + ) + + if output_format == "fbx" and not self.fbx_available: + print(">>> Warning: FBX export requested but FBX SDK is not available. Falling back to html.") + output_format = "html" + + if output_format == "fbx" and self.fbx_available: + fbx_files = self._generate_fbx_files( + visualization_data=save_data, + output_dir=output_dir, + fbx_filename=output_filename, + ) + return view_url, fbx_files, model_output + else: + raise ValueError(f">>> Invalid output format: {output_format}") + + def _generate_html_view_url( + self, + timestamp: str, + file_path: str, + output_dir: Optional[str] = None, + ) -> str: + print(f">>> HTML ready, timestamp: {timestamp}") + gradio_dir = output_dir if output_dir is not None else "output/gradio" + view_url = f"/view/{gradio_dir}/{file_path}" + return view_url + + def _generate_fbx_files( + self, + visualization_data: dict, + output_dir: Optional[str] = None, + fbx_filename: Optional[str] = None, + ) -> List[str]: + assert "smpl_data" in visualization_data, "smpl_data not found in visualization_data" + fbx_files = [] + if output_dir is None: + root_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + output_dir = os.path.join(root_dir, "output", "gradio") + + smpl_data_list = visualization_data["smpl_data"] + + unique_id = str(uuid.uuid4())[:8] + text = visualization_data["text"] + timestamp = visualization_data["timestamp"] + for bb in range(len(smpl_data_list)): + smpl_data = smpl_data_list[bb] + if fbx_filename is None: + fbx_filename_bb = f"{timestamp}_{unique_id}_{bb:03d}.fbx" + else: + fbx_filename_bb = f"{fbx_filename}_{bb:03d}.fbx" + fbx_path = os.path.join(output_dir, fbx_filename_bb) + success = self.fbx_converter.convert_npz_to_fbx(smpl_data, fbx_path) + if success: + fbx_files.append(fbx_path) + print(f"\t>>> FBX file generated: {fbx_path}") + txt_path = fbx_path.replace(".fbx", ".txt") + with open(txt_path, "w", encoding="utf-8") as f: + f.write(text) + fbx_files.append(txt_path) + + return fbx_files diff --git a/hymotion/utils/type_converter.py b/hymotion/utils/type_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1d02af7557548aff10fcff7c2e734cc6ca339a --- /dev/null +++ b/hymotion/utils/type_converter.py @@ -0,0 +1,22 @@ +import torch +from torch import nn + + +def get_module_device(module: nn.Module) -> torch.device: + """Get the device of a module. + + Args: + module (nn.Module): A module contains the parameters. + + Returns: + torch.device: The device of the module. + """ + try: + next(module.parameters()) + except StopIteration: + raise ValueError("The input module should contain parameters.") + + if next(module.parameters()).is_cuda: + return torch.device(next(module.parameters()).get_device()) + + return torch.device("cpu") diff --git a/hymotion/utils/visualize_mesh_web.py b/hymotion/utils/visualize_mesh_web.py new file mode 100644 index 0000000000000000000000000000000000000000..4f9f52fcfdaaf6f54dd2abb5a60179c15444fc99 --- /dev/null +++ b/hymotion/utils/visualize_mesh_web.py @@ -0,0 +1,342 @@ +import json +import os +import re +import threading +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from torch import Tensor + +_FILE_ACCESS_LOCK = threading.Lock() + + +def sanitize_filename(filename: str) -> str: + """ + Sanitize filename to prevent path traversal attacks + Args: + filename: original filename + Returns: + sanitized filename + """ + if not filename: + return "" + + # remove all path traversal characters + filename = re.sub(r"\.\.(/|\\\\\\)?", "", filename) + filename = filename.strip("./\\") + + # only allow letters, numbers, underscores, hyphens and dots + # dots are only allowed once in the extension + filename = re.sub(r"[^a-zA-Z0-9_.-]", "", filename) + + # prevent multiple consecutive dots + while ".." in filename: + filename = filename.replace("..", ".") + + # prevent starting with a dot (hidden file) + if filename.startswith("."): + filename = filename[1:] + + # limit file name length + if len(filename) > 255: + filename = filename[:255] + + return filename + + +def sanitize_folder_name(folder_name: str) -> str: + """ + Sanitize folder name to prevent path traversal attacks + Args: + folder_name: original folder name + Returns: + sanitized folder name + """ + if not folder_name: + return "output" # default folder + + # remove all path traversal characters + folder_name = re.sub(r"\.\.(/|\\\\\\)?", "", folder_name) + folder_name = folder_name.strip("./\\") + + # only allow letters, numbers, underscores, hyphens and slashes (for subdirectories) + # but need to ensure slashes don't cause path traversal + folder_name = re.sub(r"[^a-zA-Z0-9_./-]", "", folder_name) + + # split path and clean each part + parts = folder_name.split("/") + cleaned_parts = [] + for part in parts: + if part and part not in [".", ".."]: + # clean each part + part = re.sub(r"[^a-zA-Z0-9_-]", "", part) + if part: + cleaned_parts.append(part) + + # recombine, allow at most 3 levels of directory depth + if len(cleaned_parts) > 3: + cleaned_parts = cleaned_parts[:3] + + return "/".join(cleaned_parts) if cleaned_parts else "output" + + +def safe_path_join(base_dir: str, *paths: str) -> str: + """ + Safe path joining, ensure the resulting path is within base_dir + Args: + base_dir: base directory + *paths: paths to join + Returns: + joined path + Raises: + ValueError: if path traversal is detected + """ + # clean all paths + cleaned_paths = [] + for path in paths: + if path: + # clean each path part + path = re.sub(r"\.\.(/|\\\\\\)?", "", path) + path = path.strip("./\\") + path = re.sub(r"[^a-zA-Z0-9_.-]", "", path) + if path: + cleaned_paths.append(path) + + # join paths + full_path = os.path.join(base_dir, *cleaned_paths) + + # ensure the resulting path is within base_dir + base_dir = os.path.realpath(base_dir) + full_path = os.path.realpath(os.path.normpath(full_path)) + + if os.path.commonpath([base_dir, full_path]) != base_dir: + raise ValueError(f"Path traversal detected: {full_path} is outside {base_dir}") + + return full_path + + +def _get_root_dir() -> str: + return os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + + +def get_output_dir(sub_path: str = "") -> str: + output_base = _get_root_dir() + if not os.path.exists(output_base): + os.makedirs(output_base, exist_ok=True) + if sub_path: + parts = [p for p in sub_path.replace("\\", "/").split("/") if p] + else: + parts = [] + return safe_path_join(output_base, *parts) + + +def save_visualization_data( + output: Dict[str, Union[Tensor, list[str]]], + text: str, + rewritten_text: Union[str, list[str]], + timestamp: str, + output_dir: Optional[str] = None, + output_filename: Optional[str] = None, +): + from ..pipeline.body_model import construct_smpl_data_dict + + if output_dir is None: + output_dir = get_output_dir(sub_path="output/gradio") + os.makedirs(output_dir, exist_ok=True) + + # for metadata + base_filename = output_filename if output_filename else timestamp + meta_path = safe_path_join(output_dir, f"{base_filename}_meta.json") + if isinstance(rewritten_text, str): + rewritten_text = [rewritten_text] + batch_size = output["rot6d"].shape[0] + meta_data = { + "timestamp": timestamp, + "text": text, + "text_rewrite": rewritten_text, + "num_samples": batch_size, + "base_filename": base_filename, + } + + with _FILE_ACCESS_LOCK: + with open(meta_path, "w") as f: + json.dump(meta_data, f, indent=2) + + # for smpl data + rot6d = output["rot6d"] + transl = output["transl"] + + all_smpl_data = [] # for FBX generator + + for bb in range(batch_size): + # build data + smpl_data = construct_smpl_data_dict(rot6d[bb].clone(), transl[bb].clone()) + all_smpl_data.append(smpl_data) + + # prepare dictionary to save into NPZ + npz_dict = {} + npz_dict["gender"] = np.array([smpl_data.get("gender", "neutral")], dtype=str) + + for key in ["Rh", "trans", "poses", "betas"]: + if key in smpl_data: + val = smpl_data[key] + if isinstance(val, (list, tuple)): + val = np.array(val) + elif isinstance(val, torch.Tensor): + val = val.cpu().numpy() + npz_dict[key] = val + + # save single NPZ + sample_filename = f"{base_filename}_{bb:03d}.npz" + sample_path = safe_path_join(output_dir, sample_filename) + + with _FILE_ACCESS_LOCK: + np.savez_compressed(sample_path, **npz_dict) + + # construct memory dictionary to return (for compatibility) + memory_data = { + "timestamp": timestamp, + "text": text, + "text_rewrite": rewritten_text, + "smpl_data": all_smpl_data, + "meta_data": [], + } + + # return base filename, subsequent logic will use this as a basis for finding _meta.json or _000.npz + return memory_data, base_filename + + +def get_cached_captions(folder_name: str, file_name: str) -> List[dict]: + """read _meta.json to get text""" + + folder_name = sanitize_folder_name(folder_name) + file_name = sanitize_filename(file_name) + + base_dir = get_output_dir(folder_name) + # try to add suffix or find + meta_path = safe_path_join(base_dir, f"{file_name}_meta.json") + + if not os.path.exists(meta_path): + if "_" in file_name: + prefix = file_name.rsplit("_", 1)[0] + prefix = sanitize_filename(prefix) + meta_path_alt = safe_path_join(base_dir, f"{prefix}_meta.json") + if os.path.exists(meta_path_alt): + meta_path = meta_path_alt + else: + return [] + else: + return [] + + try: + with _FILE_ACCESS_LOCK: + with open(meta_path, "r") as f: + data = json.load(f) + + text = data.get("text", "") + text_rewrite = data.get("text_rewrite", []) + + captions = [] + for i, t in enumerate(text_rewrite): + item = {"short caption+": f"{t}", "start_time": None, "end_time": None} + if text and text != t: + item["short caption"] = text + captions.append(item) + return captions + except Exception as e: + print(f"Error reading meta json: {e}") + return [] + + +def get_cached_smpl_frames(folder_name: str, file_name: str) -> List[list]: + """ + read logic needs to be adjusted: + 1. if file_name is the base name, load all samples + 2. if file_name is a specific sample name, only load that sample + """ + folder_name = sanitize_folder_name(folder_name) + file_name = sanitize_filename(file_name) + + base_dir = get_output_dir(folder_name) + + npz_direct_path = safe_path_join(base_dir, f"{file_name}.npz") + meta_path = safe_path_join(base_dir, f"{file_name}_meta.json") + + target_indices = [] + base_name = file_name + + if os.path.isfile(npz_direct_path): + try: + if "_" in file_name: + prefix, suffix = file_name.rsplit("_", 1) + if suffix.isdigit(): + num_samples = 1 + base_name = prefix + target_indices = [int(suffix)] + else: + pass + else: + pass + except ValueError: + pass + if not target_indices: + return [] + elif os.path.exists(meta_path): + try: + with open(meta_path, "r") as f: + meta = json.load(f) + num_samples = meta.get("num_samples", 0) + target_indices = range(num_samples) + except Exception as e: + print(f"Error reading meta: {e}") + return [] + else: + return [] + + all_people = [] + + for i in target_indices: + npz_path = safe_path_join(base_dir, f"{base_name}_{i:03d}.npz") + if not os.path.exists(npz_path): + continue + + try: + with _FILE_ACCESS_LOCK: + with np.load(npz_path, allow_pickle=False) as data: + # read single person data + gender = str(data["gender"][0]) + Rh = data["Rh"] + Th = data["trans"] + poses = data["poses"] + betas = data["betas"] + + if poses.ndim == 3: + poses = poses.reshape(poses.shape[0], -1) + + person_frames = [] + for f in range(len(poses)): + frame = { + "id": i, + "gender": gender, + "Rh": Rh[f : f + 1].tolist(), + "Th": Th[f : f + 1].tolist(), + "poses": poses[f : f + 1].tolist(), + "shapes": betas.tolist(), + } + person_frames.append([frame]) + all_people.append(person_frames) + except Exception as e: + print(f"Error loading {npz_path}: {e}") + + # merge + combined_frames = [] + max_frames = max(len(p) for p in all_people) if all_people else 0 + for f_idx in range(max_frames): + frame_content = [] + for person_seq in all_people: + if f_idx < len(person_seq): + frame_content.extend(person_seq[f_idx]) + combined_frames.append(frame_content) + + return combined_frames diff --git a/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp b/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp new file mode 100644 index 0000000000000000000000000000000000000000..7e515e321713c3fd3dfd5f756b0fb21d42e0f84f --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_BaseColor.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aaed26cb89635f9d995c9c373919185c412f00f6b4a903873ac75ccd7a549439 +size 228228 diff --git a/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp b/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp new file mode 100644 index 0000000000000000000000000000000000000000..90cf7f0ca103ef8f9332169349fe7187ff978449 --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_Normal.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe4bd8b80aadf6e414c3c07acc57532cda754deb93f0cdc6f211387b8beca97d +size 30430 diff --git a/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp b/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp new file mode 100644 index 0000000000000000000000000000000000000000..1f0412ab97068809d350bf8157dde22893413c3f --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/Boy_lambert4_OcclusionRoughnessMetallic.webp @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ad8bd31dfda3a8356ac5c9d6bdc7457f0197ae23754f9ac196f86d872b26781 +size 206722 diff --git a/scripts/gradio/static/assets/dump_wooden/faces.bin b/scripts/gradio/static/assets/dump_wooden/faces.bin new file mode 100644 index 0000000000000000000000000000000000000000..06ceddf590710f03ca4b1ec517993e820d6597e6 --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/faces.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:777b0806d2843c797ed18644eecc11466ff822b33b02d263c22f8ad3730e9bb5 +size 290376 diff --git a/scripts/gradio/static/assets/dump_wooden/j_template.bin b/scripts/gradio/static/assets/dump_wooden/j_template.bin new file mode 100644 index 0000000000000000000000000000000000000000..cf614efd7c18d17aac67192196c87371b68c5adc --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/j_template.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f488cc9d0b650a816f5a8eda49eda7bc796f490e9130a6e0dec5be137d7b929 +size 624 diff --git a/scripts/gradio/static/assets/dump_wooden/joint_names.json b/scripts/gradio/static/assets/dump_wooden/joint_names.json new file mode 100644 index 0000000000000000000000000000000000000000..95415b3786ca61223d666faad98b0452e6f99300 --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/joint_names.json @@ -0,0 +1,54 @@ +[ + "Pelvis", + "L_Hip", + "R_Hip", + "Spine1", + "L_Knee", + "R_Knee", + "Spine2", + "L_Ankle", + "R_Ankle", + "Spine3", + "L_Foot", + "R_Foot", + "Neck", + "L_Collar", + "R_Collar", + "Head", + "L_Shoulder", + "R_Shoulder", + "L_Elbow", + "R_Elbow", + "L_Wrist", + "R_Wrist", + "L_Index1", + "L_Index2", + "L_Index3", + "L_Middle1", + "L_Middle2", + "L_Middle3", + "L_Pinky1", + "L_Pinky2", + "L_Pinky3", + "L_Ring1", + "L_Ring2", + "L_Ring3", + "L_Thumb1", + "L_Thumb2", + "L_Thumb3", + "R_Index1", + "R_Index2", + "R_Index3", + "R_Middle1", + "R_Middle2", + "R_Middle3", + "R_Pinky1", + "R_Pinky2", + "R_Pinky3", + "R_Ring1", + "R_Ring2", + "R_Ring3", + "R_Thumb1", + "R_Thumb2", + "R_Thumb3" +] \ No newline at end of file diff --git a/scripts/gradio/static/assets/dump_wooden/joints.ply b/scripts/gradio/static/assets/dump_wooden/joints.ply new file mode 100644 index 0000000000000000000000000000000000000000..cf70efc97b3f86c1dc0021c5dbd0f5f3d6c9fe0e Binary files /dev/null and b/scripts/gradio/static/assets/dump_wooden/joints.ply differ diff --git a/scripts/gradio/static/assets/dump_wooden/keypoints.bin b/scripts/gradio/static/assets/dump_wooden/keypoints.bin new file mode 100644 index 0000000000000000000000000000000000000000..cf614efd7c18d17aac67192196c87371b68c5adc --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/keypoints.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f488cc9d0b650a816f5a8eda49eda7bc796f490e9130a6e0dec5be137d7b929 +size 624 diff --git a/scripts/gradio/static/assets/dump_wooden/kintree.bin b/scripts/gradio/static/assets/dump_wooden/kintree.bin new file mode 100644 index 0000000000000000000000000000000000000000..1f0261f9d2d2b7daae9a6775b0f769a62e3a90af --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/kintree.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98a20fa3b53193790b63d9ac3a9c917f2f70fbe6e053dca495c75317ff4b756a +size 208 diff --git a/scripts/gradio/static/assets/dump_wooden/skinIndice.bin b/scripts/gradio/static/assets/dump_wooden/skinIndice.bin new file mode 100644 index 0000000000000000000000000000000000000000..3819f48cdc4d6e2175c2663855919164f20b5fe9 --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/skinIndice.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:846794fb90ea01e069435ad242caefb3d5c2f913fef3255c247f48836ddc1bda +size 194048 diff --git a/scripts/gradio/static/assets/dump_wooden/skinWeights.bin b/scripts/gradio/static/assets/dump_wooden/skinWeights.bin new file mode 100644 index 0000000000000000000000000000000000000000..a6fb57ebf01a3a7991a0d710340b876a39d3e2af --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/skinWeights.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:343eac2902627ee6b45b547eb7d9f1526562eca7ba178d1dc71f9e466f307a77 +size 388096 diff --git a/scripts/gradio/static/assets/dump_wooden/uvs.bin b/scripts/gradio/static/assets/dump_wooden/uvs.bin new file mode 100644 index 0000000000000000000000000000000000000000..7c8abfc32bb0cb46a0f35bf6d1217abb96a3ce37 --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/uvs.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23c92c60b609261d927990d31cbf6ace0c14cb70a5ac753a5f3927cb8c5c8191 +size 194048 diff --git a/scripts/gradio/static/assets/dump_wooden/v_template.bin b/scripts/gradio/static/assets/dump_wooden/v_template.bin new file mode 100644 index 0000000000000000000000000000000000000000..364b797b82a60e30a8b9287ccea86d8167120f79 --- /dev/null +++ b/scripts/gradio/static/assets/dump_wooden/v_template.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9fbe1f34bfe8a07442d11166e169318022a18da8bc62ce0a9930dfdb3171050 +size 291072 diff --git a/scripts/gradio/static/scripts3d/create_ground.js b/scripts/gradio/static/scripts3d/create_ground.js new file mode 100644 index 0000000000000000000000000000000000000000..7124a9db7508a48053d225ba2da1b88fbfe36049 --- /dev/null +++ b/scripts/gradio/static/scripts3d/create_ground.js @@ -0,0 +1,191 @@ +import * as THREE from "three"; + +// extract common adaptive logic +function getAdaptiveGridSize(sample_data, default_size = 5) { + if (sample_data) { + const bounds = calculateDataBounds(sample_data); + const grid_size = Math.max(bounds.maxRange * 3, 5); // 1.5x margin + console.log(`Adaptive ground size: ${grid_size.toFixed(2)}, data range: ${bounds.maxRange.toFixed(2)}`); + return grid_size; + } + return default_size; +} + +function createBaseChessboard( + grid_size = 5, + divisions = 10, + white = "#ffffff", + black = "#444444", + texture_size = 1024, + sample_data = null, +) { + // Use adaptive sizing if sample_data provided, otherwise use fixed grid_size + if (sample_data) { + grid_size = getAdaptiveGridSize(sample_data, grid_size); + } + + // Create chessboard texture with enhanced visual style + // Ensure texture_size is divisible by divisions to avoid sub-pixel rendering + var adjusted_texture_size = Math.floor(texture_size / divisions) * divisions; + var canvas = document.createElement("canvas"); + canvas.width = canvas.height = adjusted_texture_size; + var context = canvas.getContext("2d"); + + // Disable anti-aliasing for crisp edges + context.imageSmoothingEnabled = false; + + var step = adjusted_texture_size / divisions; // Now guaranteed to be an integer + for (var i = 0; i < divisions; i++) { + for (var j = 0; j < divisions; j++) { + context.fillStyle = (i + j) % 2 === 0 ? white : black; + context.fillRect(i * step, j * step, step, step); + } + } + + var texture = new THREE.CanvasTexture(canvas); + // Use NearestFilter for sharp/crisp edges between chess squares + texture.wrapS = THREE.RepeatWrapping; + texture.wrapT = THREE.RepeatWrapping; + texture.magFilter = THREE.NearestFilter; + texture.minFilter = THREE.NearestFilter; + texture.generateMipmaps = false; + + // Create plane geometry + var planeGeometry = new THREE.PlaneGeometry(grid_size, grid_size); + + // Enhanced material with better visual properties + var planeMaterial = new THREE.MeshStandardMaterial({ + map: texture, + side: THREE.DoubleSide, + transparent: true, + opacity: 0.85, + roughness: 0.9, + metalness: 0.1, + emissiveIntensity: 0.05, + }); + + // Create grid mesh + var plane = new THREE.Mesh(planeGeometry, planeMaterial); + plane.receiveShadow = true; + + return plane; +} + +function getChessboard(...args) { + var plane = createBaseChessboard(...args); + plane.rotation.x = -Math.PI; // rotate to make the plane horizontal + return plane; +} + +function getChessboardXZ(...args) { + var plane = createBaseChessboard(...args); + plane.rotation.x = -Math.PI / 2; // rotate to make the plane horizontal + return plane; +} + +function getCoordinate(axisLength) { + // create a group to store the coordinate axes + var axes = new THREE.Group(); + + // define the material of the axes + var materialX = new THREE.LineBasicMaterial({ color: 0xff0000 }); // red X axis + var materialY = new THREE.LineBasicMaterial({ color: 0x00ff00 }); // green Y axis + var materialZ = new THREE.LineBasicMaterial({ color: 0x0000ff }); // blue Z axis + + // create axis lines (X axis, Y axis, Z axis) + var xAxisGeometry = new THREE.BufferGeometry().setFromPoints([ + new THREE.Vector3(0, 0, 0), + new THREE.Vector3(axisLength, 0, 0), + ]); + var yAxisGeometry = new THREE.BufferGeometry().setFromPoints([ + new THREE.Vector3(0, 0, 0), + new THREE.Vector3(0, axisLength, 0), + ]); + var zAxisGeometry = new THREE.BufferGeometry().setFromPoints([ + new THREE.Vector3(0, 0, 0), + new THREE.Vector3(0, 0, axisLength), + ]); + + var xAxis = new THREE.Line(xAxisGeometry, materialX); + var yAxis = new THREE.Line(yAxisGeometry, materialY); + var zAxis = new THREE.Line(zAxisGeometry, materialZ); + + // add axes to the group + axes.add(xAxis); + axes.add(yAxis); + axes.add(zAxis); + + return axes; +} + +function calculateDataBounds(sample_data) { + let minX = Infinity, + maxX = -Infinity; + let minY = Infinity, + maxY = -Infinity; + let minZ = Infinity, + maxZ = -Infinity; + + // iterate through sample_data to find the maximum and minimum values + if (sample_data && sample_data.length > 0) { + sample_data.forEach((frame) => { + if (frame.positions && Array.isArray(frame.positions)) { + frame.positions.forEach((pos) => { + // support multiple position data formats + let x, y, z; + if (typeof pos === "object") { + x = pos.x !== undefined ? pos.x : pos[0]; + y = pos.y !== undefined ? pos.y : pos[1]; + z = pos.z !== undefined ? pos.z : pos[2]; + } else if (Array.isArray(pos)) { + [x, y, z] = pos; + } + + if (x !== undefined && y !== undefined && z !== undefined) { + minX = Math.min(minX, x); + maxX = Math.max(maxX, x); + minY = Math.min(minY, y); + maxY = Math.max(maxY, y); + minZ = Math.min(minZ, z); + maxZ = Math.max(maxZ, z); + } + }); + } + }); + } + + // if no valid data is found, use default values + if (minX === Infinity || maxX === -Infinity) { + minX = maxX = minY = maxY = minZ = maxZ = 0; + } + + const rangeX = Math.abs(maxX - minX); + const rangeY = Math.abs(maxY - minY); + const rangeZ = Math.abs(maxZ - minZ); + + // calculate the maximum range of the XZ plane (the ground mainly cares about the movement of the X and Z axes) + const maxRange = Math.max(rangeX, rangeZ); + + // add debug information + console.log( + `Data boundaries: X[${minX.toFixed(2)}, ${maxX.toFixed(2)}], Y[${minY.toFixed(2)}, ${maxY.toFixed(2)}], Z[${minZ.toFixed(2)}, ${maxZ.toFixed(2)}]`, + ); + console.log( + `Ranges: X=${rangeX.toFixed(2)}, Y=${rangeY.toFixed(2)}, Z=${rangeZ.toFixed(2)}, Max=${maxRange.toFixed(2)}`, + ); + + return { + minX, + maxX, + minY, + maxY, + minZ, + maxZ, + rangeX, + rangeY, + rangeZ, + maxRange, + }; +} + +export { calculateDataBounds, getChessboard, getChessboardXZ, getCoordinate }; diff --git a/scripts/gradio/static/scripts3d/create_scene.js b/scripts/gradio/static/scripts3d/create_scene.js new file mode 100644 index 0000000000000000000000000000000000000000..f512ac2f5a29f7756c69e5d5026863a25529f340 --- /dev/null +++ b/scripts/gradio/static/scripts3d/create_scene.js @@ -0,0 +1,195 @@ +import * as THREE from "three"; +import { getChessboard, getChessboardXZ, getCoordinate } from "./create_ground.js"; + +function create_plane(scene) { + const planeGeometry = new THREE.PlaneGeometry(20, 20); + const planeMaterial = new THREE.MeshStandardMaterial({ color: 0x808080 }); + const plane = new THREE.Mesh(planeGeometry, planeMaterial); + plane.position.y = -1; + plane.receiveShadow = true; // make the plane receive shadows + scene.add(plane); +} + +function create_cube(scene) { + // add a cube + const cubeGeometry = new THREE.BoxGeometry(); + const cubeMaterial = new THREE.MeshPhongMaterial({ color: 0xffffff }); + const cube = new THREE.Mesh(cubeGeometry, cubeMaterial); + cube.position.y = 1; + cube.castShadow = true; // make the cube cast shadows + scene.add(cube); +} + +function create_scene(scene, camera, renderer, use_ground = true, axis_up = "z", axis_forward = "-y") { + const width = document.querySelector(".container").offsetWidth; + const height = width; + + // Camera setup based on axis orientation + if (axis_up == "z") { + camera.up.set(0, 0, 1); + if (axis_forward == "-y") { + camera.position.set(0, -3, 3); + } else if (axis_forward == "y") { + camera.position.set(0, 3, 3); + } + camera.lookAt(new THREE.Vector3(0, 0, 1.5)); + } else if (axis_up == "y") { + camera.up.set(0, 1, 0); + if (axis_forward == "z") { + camera.position.set(0, 2.5, 5); + } else if (axis_forward == "-z") { + camera.position.set(0, 2.5, -5); + } + camera.lookAt(new THREE.Vector3(0, 1, 0)); + } + + scene.background = new THREE.Color(0x000000); + + // ===== Fog for depth perception ===== + // Using FogExp2 for natural exponential falloff, density ~0.06 + scene.fog = new THREE.FogExp2(0x424242, 0.06); + + // ===== Shadow Configuration ===== + renderer.shadowMap.enabled = true; + renderer.shadowMap.type = THREE.PCFSoftShadowMap; + + // ===== Enhanced Lighting Setup ===== + + // 1. Hemisphere Light - natural sky/ground ambient + const hemisphereLight = new THREE.HemisphereLight( + 0xffffff, // sky color + 0x444444, // ground color + 1.8 // intensity + ); + hemisphereLight.position.set(0, 2, 0); + scene.add(hemisphereLight); + + // 2. Main Directional Light (key light with shadows) + const directionalLight = new THREE.DirectionalLight(0xffffff, 1.5); + if (axis_up == "z") { + if (axis_forward == "-y") { + directionalLight.position.set(-3, 1, 5); + } else if (axis_forward == "y") { + directionalLight.position.set(3, 1, 5); + } + } else if (axis_up == "y") { + if (axis_forward == "z") { + directionalLight.position.set(3, 5, 4); + } else if (axis_forward == "-z") { + directionalLight.position.set(3, 5, -4); + } + } + directionalLight.castShadow = true; + directionalLight.shadow.mapSize.width = 2048; + directionalLight.shadow.mapSize.height = 2048; + directionalLight.shadow.camera.near = 0.5; + directionalLight.shadow.camera.far = 50; + directionalLight.shadow.camera.left = -10; + directionalLight.shadow.camera.right = 10; + directionalLight.shadow.camera.top = 10; + directionalLight.shadow.camera.bottom = -10; + directionalLight.shadow.bias = -0.0001; + scene.add(directionalLight); + + // 3. Fill Light (softer, from opposite side) + const fillLight = new THREE.DirectionalLight(0xaaccff, 0.4); + fillLight.position.set(-3, 3, -2); + scene.add(fillLight); + + // 4. Rim Light (back light for depth) + const rimLight = new THREE.DirectionalLight(0xffeedd, 0.3); + rimLight.position.set(0, 4, -5); + scene.add(rimLight); + + // ===== Ground Setup ===== + if (use_ground) { + if (axis_up == "z") { + var plane = getChessboard(50, 50, '#ffffff', '#3a3a3a', 1024); + plane.name = 'ground'; + plane.receiveShadow = true; + scene.add(plane); + } else if (axis_up == "y") { + var plane = getChessboardXZ(50, 50, '#ffffff', '#3a3a3a', 1024); + plane.name = 'ground'; + plane.receiveShadow = true; + scene.add(plane); + } + + // Optional: coordinate axes helper + // var coord = getCoordinate(1); + // scene.add(coord); + } + + return 0; +} + +function fitCameraToScene(scene, camera, controls = null, opts = {}) { + const { margin = 1.05, axis_up = "y", excludeNames = ["ground"] } = opts; + + const box = new THREE.Box3(); + const tmp = new THREE.Box3(); + let has = false; + + scene.traverse((obj) => { + if (!obj || !obj.visible) return; + if (obj.isLight) return; + const t = obj.type || ""; + if (t.endsWith("Helper")) return; + if (excludeNames && excludeNames.includes(obj.name)) return; + + if (obj.isMesh) { + if (obj.geometry && obj.geometry.type === "PlaneGeometry") return; + try { + tmp.setFromObject(obj); + if (!tmp.isEmpty()) { + if (!has) { + box.copy(tmp); + has = true; + } else { + box.union(tmp); + } + } + } catch (_) {} + } + }); + + if (!has || box.isEmpty()) return; + + const sphere = new THREE.Sphere(); + box.getBoundingSphere(sphere); + const center = sphere.center.clone(); + const radius = Math.max(sphere.radius, 1e-3); + + const vFov = THREE.MathUtils.degToRad(camera.fov); + const hFov = 2 * Math.atan(Math.tan(vFov / 2) * camera.aspect); + const distV = radius / Math.sin(vFov / 2); + const distH = radius / Math.sin(hFov / 2); + const dist = Math.max(distV, distH) * margin; + + // 25° top-down view (azimuth 45°, elevation 25°) + const elev = THREE.MathUtils.degToRad(25); + const azim = Math.PI / 4; + const horiz = Math.cos(elev); + let dir; + + if (axis_up === "y") { + dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.sin(elev), Math.cos(azim) * horiz); + camera.up.set(0, 1, 0); + } else { + dir = new THREE.Vector3(Math.sin(azim) * horiz, Math.cos(azim) * horiz, Math.sin(elev)); + camera.up.set(0, 0, 1); + } + + camera.position.copy(center).add(dir.multiplyScalar(dist)); + camera.updateProjectionMatrix(); + camera.lookAt(center); + + if (controls) { + controls.target.copy(center); + controls.minDistance = Math.max(radius * 0.2, 0.1); + controls.maxDistance = Math.max(dist * 3, controls.minDistance + 0.1); + controls.update(); + } +} + +export { create_scene, fitCameraToScene }; diff --git a/scripts/gradio/static/scripts3d/draw_skeleton.js b/scripts/gradio/static/scripts3d/draw_skeleton.js new file mode 100644 index 0000000000000000000000000000000000000000..709a33e64e4427e4540ca6e2d0f8fa41e3f2f874 --- /dev/null +++ b/scripts/gradio/static/scripts3d/draw_skeleton.js @@ -0,0 +1,121 @@ +import * as THREE from "three"; + +const defaultEdges = [ + [1, 0], + [2, 1], + [3, 2], + [4, 3], + [5, 1], + [6, 5], + [7, 6], + [8, 1], + [9, 8], + [10, 9], + [11, 10], + [12, 8], + [13, 12], + [14, 13], + [15, 0], + [16, 0], + [17, 15], + [18, 16], + [19, 14], + [20, 19], + [21, 14], + [22, 11], + [23, 22], + [24, 11], +]; + +var geometries = []; + +function clearGeometries(scene) { + geometries.forEach((obj) => { + scene.remove(obj); + if (obj.geometry) obj.geometry.dispose(); + if (obj.material) obj.material.dispose(); + }); + geometries = []; +} + +function drawJoints(keypoints, scene, radius_joint) { + const sphereGeometry = new THREE.SphereGeometry(radius_joint, 32, 32); + const sphereMaterial = new THREE.MeshStandardMaterial({ color: 0xff0000 }); + + keypoints.forEach((point) => { + // Check visibility if confidence score exists + if (point.length > 3 && point[3] < 0.1) { + return; + } + + const sphere = new THREE.Mesh(sphereGeometry, sphereMaterial); + sphere.position.set(point[0], point[1], point[2]); + geometries.push(sphere); + scene.add(sphere); + }); +} + +function drawLimbs(keypoints, edges, scene, radius_limb) { + const ellipsoidGeometry = new THREE.SphereGeometry(radius_limb, 32, 32); + const ellipsoidMaterial = new THREE.MeshStandardMaterial({ color: 0x0000ff }); + + edges.forEach((edge) => { + const idx1 = edge[0]; + const idx2 = edge[1]; + + // Validate indices + if (idx1 >= keypoints.length || idx2 >= keypoints.length) { + return; + } + + // Check visibility + const p1 = keypoints[idx1]; + const p2 = keypoints[idx2]; + if ( + (p1.length > 3 && p1[3] < 0.1) || + (p2.length > 3 && p2[3] < 0.1) + ) { + return; + } + + const start = new THREE.Vector3(p1[0], p1[1], p1[2]); + const end = new THREE.Vector3(p2[0], p2[1], p2[2]); + + const direction = new THREE.Vector3().subVectors(end, start); + const length = direction.length(); + + // create an ellipsoid + const ellipsoid = new THREE.Mesh(ellipsoidGeometry, ellipsoidMaterial); + + // scale: x,y = 1 (radius_limb), z matches length + ellipsoid.scale.set(1, 1, length / 2 / radius_limb); + + // position: midpoint + ellipsoid.position.addVectors(start, end).multiplyScalar(0.5); + + // rotation: point to end + ellipsoid.lookAt(end); + + geometries.push(ellipsoid); + scene.add(ellipsoid); + }); +} + +function drawSingleSkeleton(keypoints, edges, scene, radius_joint, radius_limb) { + drawJoints(keypoints, scene, radius_joint); + drawLimbs(keypoints, edges, scene, radius_limb); +} + +function visualizeSkeleton(keypoints, scene, radius_joint = 0.02, radius_limb = 0.03) { + clearGeometries(scene); + drawSingleSkeleton(keypoints, defaultEdges, scene, radius_joint, radius_limb); +} + +function visualizeAllSkeleton(infos, scene, radius_joint = 0.02, radius_limb = 0.03) { + clearGeometries(scene); + infos.forEach((info) => { + drawSingleSkeleton(info.keypoints3d, info.edges, scene, radius_joint, radius_limb); + }); +} + +export { visualizeAllSkeleton, visualizeSkeleton }; diff --git a/scripts/gradio/static/scripts3d/load_smpl.js b/scripts/gradio/static/scripts3d/load_smpl.js new file mode 100644 index 0000000000000000000000000000000000000000..af1d17d3a56c8aea33089affe7461d6c018ebba9 --- /dev/null +++ b/scripts/gradio/static/scripts3d/load_smpl.js @@ -0,0 +1,126 @@ +import * as THREE from "three"; + +const NUM_SKIN_WEIGHTS = 4; + +async function load_smpl_with_shapes(shapes, gender) { + const urls = { + neutral: [ + "/static/assets/dump_smplh/v_template.bin", + "/static/assets/dump_smplh/faces.bin", + "/static/assets/dump_smplh/skinWeights.bin", + "/static/assets/dump_smplh/skinIndice.bin", + "/static/assets/dump_smplh/j_template.bin", + ], + }[gender]; + const gender_color = { + neutral: 0xffffff, + male: 0x6495ed, // Cornflower blue (lighter blue) + female: 0xff6b81, // Light coral (softer red) + }; + + console.log(shapes.length); + const geometry = new THREE.BufferGeometry(); + const buffers = await Promise.all(urls.map((url) => fetch(url).then((response) => response.arrayBuffer()))); + const v_template = new Float32Array(buffers[0]); + const offsets = await Promise.all( + shapes.map((_, i) => + fetch("/static/assets/dump_smplh/shapeoffset_" + i + ".bin") + .then((response) => response.arrayBuffer()) + .then((buffer) => new Float32Array(buffer)), + ), + ); + const offsets_j = await Promise.all( + shapes.map((_, i) => + fetch("/static/assets/dump_smplh/shapeoffset_j_" + i + ".bin") + .then((response) => response.arrayBuffer()) + .then((buffer) => new Float32Array(buffer)), + ), + ); + offsets.forEach((offset, i) => { + for (let j = 0; j < v_template.length / 3; j++) { + v_template[3 * j] += offset[3 * j] * shapes[i]; + v_template[3 * j + 1] += offset[3 * j + 1] * shapes[i]; + v_template[3 * j + 2] += offset[3 * j + 2] * shapes[i]; + } + }); + const faces = new Uint16Array(buffers[1]); + const skinWeights = new Float32Array(buffers[2]); + const skinIndices = new Uint16Array(buffers[3]); + + const keypoints = new Float32Array(buffers[4]); + for (let i = 0; i < keypoints.length / 3; i++) { + console.log("keypoints", keypoints[3 * i], keypoints[3 * i + 1], keypoints[3 * i + 2]); + } + + offsets_j.forEach((offset_j, i) => { + console.log("shape id", i, shapes[i]); + console.log("keypoints", keypoints[0], keypoints[1], keypoints[2]); + console.log("offset_j", offset_j[0], offset_j[1], offset_j[2]); + + for (let j = 0; j < keypoints.length / 3; j++) { + keypoints[3 * j] += offset_j[3 * j] * shapes[i]; + keypoints[3 * j + 1] += offset_j[3 * j + 1] * shapes[i]; + keypoints[3 * j + 2] += offset_j[3 * j + 2] * shapes[i]; + } + }); + + // edges contain the skeleton link relationship + // const edges = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 21]; + const edges = [ + -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, + 31, 32, 20, 34, 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50, + ]; + // assume jointPositions is a J x 3 array, each element is an array containing X, Y, Z coordinates + var rootBone = new THREE.Bone(); + rootBone.position.set(keypoints[0], keypoints[1], keypoints[2]); + // scene.add(rootBone); + var bones = [rootBone]; + // create bones + for (let i = 1; i < keypoints.length / 3; i++) { + const bone = new THREE.Bone(); + const parentIndex = edges[i]; + bone.position.set( + keypoints[3 * i] - keypoints[3 * parentIndex], + keypoints[3 * i + 1] - keypoints[3 * parentIndex + 1], + keypoints[3 * i + 2] - keypoints[3 * parentIndex + 2], + ); + console.log(i, bone.position); + bones.push(bone); + bones[parentIndex].add(bone); + } + var skeleton = new THREE.Skeleton(bones); + geometry.setIndex(new THREE.BufferAttribute(faces, 1)); + + geometry.setAttribute("position", new THREE.BufferAttribute(v_template, 3)); + geometry.setAttribute("skinIndex", new THREE.BufferAttribute(skinIndices, NUM_SKIN_WEIGHTS)); + geometry.setAttribute("skinWeight", new THREE.BufferAttribute(skinWeights, NUM_SKIN_WEIGHTS)); + + geometry.computeVertexNormals(); + console.log(geometry); + const material = new THREE.MeshStandardMaterial({ + color: gender_color[gender], + skinning: true, + side: THREE.DoubleSide, + }); + var mesh = new THREE.SkinnedMesh(geometry, material); + mesh.castShadow = true; + mesh.receiveShadow = true; + mesh.add(bones[0]); + mesh.bind(skeleton); + return { bones, skeleton, mesh }; +} + +function reshapeArrayTo2D(float32Array, rows) { + const twoDArray = []; + const cols = float32Array.length / rows; + for (let i = 0; i < rows; i++) { + const row = new Float32Array(cols); + for (let j = 0; j < cols; j++) { + row[j] = float32Array[i * cols + j]; + } + twoDArray.push(row); + } + return twoDArray; +} + +export { load_smpl_with_shapes }; diff --git a/scripts/gradio/static/scripts3d/load_wooden.js b/scripts/gradio/static/scripts3d/load_wooden.js new file mode 100644 index 0000000000000000000000000000000000000000..24af2995afa3a9f319ae90f433b93e06fd5e9cfb --- /dev/null +++ b/scripts/gradio/static/scripts3d/load_wooden.js @@ -0,0 +1,167 @@ +import * as THREE from 'three'; + +const NUM_SKIN_WEIGHTS = 4; + +// SMPL-H joint names (52 joints) +const SMPLH_JOINT_NAMES = [ + "Pelvis", "L_Hip", "R_Hip", "Spine1", + "L_Knee", "R_Knee", "Spine2", + "L_Ankle", "R_Ankle", "Spine3", + "L_Foot", "R_Foot", "Neck", "L_Collar", "R_Collar", "Head", + "L_Shoulder", "R_Shoulder", "L_Elbow", "R_Elbow", + "L_Wrist", "R_Wrist", + "L_Index1", "L_Index2", "L_Index3", + "L_Middle1", "L_Middle2", "L_Middle3", + "L_Pinky1", "L_Pinky2", "L_Pinky3", + "L_Ring1", "L_Ring2", "L_Ring3", + "L_Thumb1", "L_Thumb2", "L_Thumb3", + "R_Index1", "R_Index2", "R_Index3", + "R_Middle1", "R_Middle2", "R_Middle3", + "R_Pinky1", "R_Pinky2", "R_Pinky3", + "R_Ring1", "R_Ring2", "R_Ring3", + "R_Thumb1", "R_Thumb2", "R_Thumb3", +]; + +// Default kintree (parent indices) for SMPL-H 52 joints +const DEFAULT_EDGES = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 21, 37, 38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50]; + +/** + * Load wooden model from binary files + * @param {Array} shapes - Shape parameters (unused for wooden model) + * @param {string} gender - Gender parameter (unused for wooden model) + * @returns {Object} { bones, skeleton, mesh, jointNames } + */ +async function load_wooden(shapes, gender, basePath = '/static/assets/dump_wooden') { + console.log("Loading wooden model..."); + console.log(`Using base path: ${basePath}`); + + const urls = [ + `${basePath}/v_template.bin`, + `${basePath}/faces.bin`, + `${basePath}/skinWeights.bin`, + `${basePath}/skinIndice.bin`, + `${basePath}/j_template.bin`, + `${basePath}/uvs.bin`, + ]; + + // Try to load kintree + let edges = [...DEFAULT_EDGES]; + try { + const kintreeResponse = await fetch(`${basePath}/kintree.bin`); + if (kintreeResponse.ok) { + const kintreeBuffer = await kintreeResponse.arrayBuffer(); + edges = Array.from(new Int32Array(kintreeBuffer)); + console.log(`Loaded kintree with ${edges.length} joints`); + } + } catch (e) { + console.log('Using default kintree'); + } + + // Try to load joint names + let jointNames = [...SMPLH_JOINT_NAMES]; + try { + const namesResponse = await fetch(`${basePath}/joint_names.json`); + if (namesResponse.ok) { + jointNames = await namesResponse.json(); + console.log(`Loaded ${jointNames.length} joint names`); + } + } catch (e) { + console.log('Using default joint names'); + } + + // Load main buffers + const buffers = await Promise.all(urls.map(url => fetch(url).then(response => response.arrayBuffer()))); + const v_template = new Float32Array(buffers[0]); + const faces = new Uint16Array(buffers[1]); + const skinWeights = new Float32Array(buffers[2]); + const skinIndices = new Uint16Array(buffers[3]); + const keypoints = new Float32Array(buffers[4]); + const uvs = new Float32Array(buffers[5]); + + console.log(`Vertices: ${v_template.length / 3}, Faces: ${faces.length / 3}, Joints: ${keypoints.length / 3}`); + + // Create geometry + const geometry = new THREE.BufferGeometry(); + geometry.setAttribute('position', new THREE.BufferAttribute(v_template, 3)); + geometry.setIndex(new THREE.BufferAttribute(faces, 1)); + geometry.setAttribute('skinIndex', new THREE.BufferAttribute(skinIndices, NUM_SKIN_WEIGHTS)); + geometry.setAttribute('skinWeight', new THREE.BufferAttribute(skinWeights, NUM_SKIN_WEIGHTS)); + geometry.setAttribute('uv', new THREE.BufferAttribute(uvs, 2)); + + // Create bones + const numJoints = keypoints.length / 3; + + // Ensure edges array matches joint count + while (edges.length < numJoints) { + edges.push(0); + } + + // Root bone + var rootBone = new THREE.Bone(); + rootBone.position.set(keypoints[0], keypoints[1], keypoints[2]); + rootBone.name = jointNames[0] || 'Pelvis'; + var bones = [rootBone]; + + // Create child bones + for (let i = 1; i < numJoints; i++) { + const bone = new THREE.Bone(); + const parentIndex = edges[i]; + + if (parentIndex >= 0 && parentIndex < i) { + bone.position.set( + keypoints[3 * i] - keypoints[3 * parentIndex], + keypoints[3 * i + 1] - keypoints[3 * parentIndex + 1], + keypoints[3 * i + 2] - keypoints[3 * parentIndex + 2] + ); + bone.name = jointNames[i] || `Joint_${i}`; + bones.push(bone); + bones[parentIndex].add(bone); + console.log(`Joint ${i} (${bone.name}): parent=${parentIndex}, pos=${bone.position.toArray()}`); + } else { + console.warn(`Invalid parent index ${parentIndex} for joint ${i}, attaching to root`); + bone.position.set(0, 0, 0); + bone.name = jointNames[i] || `Joint_${i}`; + bones.push(bone); + bones[0].add(bone); + } + } + + var skeleton = new THREE.Skeleton(bones); + + geometry.computeVertexNormals(); + + // --- Texture Loading --- + const textureLoader = new THREE.TextureLoader(); + + async function loadTextureAsync(url, isSRGB = true) { + const tex = await textureLoader.loadAsync(url); + tex.flipY = false; + if (isSRGB) tex.colorSpace = THREE.SRGBColorSpace; + return tex; + } + + const [baseColorMap] = await Promise.all([ + loadTextureAsync(`${basePath}/Boy_lambert4_BaseColor.webp`, true), + ]); + + // Create material - PBR with textures (optimized for dark mode) + const material = new THREE.MeshStandardMaterial({ + map: baseColorMap, + roughness: 0.6, // Lower roughness for better light reflection + metalness: 0.2, // Lower metalness for more natural look + envMapIntensity: 1.5, // Enhanced environment lighting + }); + + var mesh = new THREE.SkinnedMesh(geometry, material); + mesh.castShadow = true; + mesh.receiveShadow = true; + mesh.add(bones[0]); + mesh.bind(skeleton); + + console.log(`Wooden model loaded: ${numJoints} joints, ${v_template.length / 3} vertices`); + + return { bones, skeleton, mesh, jointNames, edges }; +} + +export { DEFAULT_EDGES, load_wooden, NUM_SKIN_WEIGHTS, SMPLH_JOINT_NAMES }; + diff --git a/scripts/gradio/templates/element/blank.html b/scripts/gradio/templates/element/blank.html new file mode 100644 index 0000000000000000000000000000000000000000..14c930a0a7fee10f1d709f84af0714da36f42614 --- /dev/null +++ b/scripts/gradio/templates/element/blank.html @@ -0,0 +1,53 @@ + + + + + {% block title %} {% endblock %} + + + + + + + + + + + + + + {% block content_block %} + {% endblock %} + + {% block script_block %} + {% endblock %} + + + + + diff --git a/scripts/gradio/templates/error_file_not_found.html b/scripts/gradio/templates/error_file_not_found.html new file mode 100644 index 0000000000000000000000000000000000000000..90fe482fc8464ea9b4b67d14ddafd5aee01d24e5 --- /dev/null +++ b/scripts/gradio/templates/error_file_not_found.html @@ -0,0 +1,64 @@ + + + + + + + File not found - 404 + + + + +
+

404

+

Oops! File Not Found

+

We couldn't find the file you're looking for.
Please check the URL or try again later.

+
+ + + diff --git a/scripts/gradio/templates/index_smpl_gradio.html b/scripts/gradio/templates/index_smpl_gradio.html new file mode 100644 index 0000000000000000000000000000000000000000..665241d599cb18a360af1994eb6ab43ef6174819 --- /dev/null +++ b/scripts/gradio/templates/index_smpl_gradio.html @@ -0,0 +1,938 @@ +{% extends 'element/blank.html' %} + +{% block content_block %} + +
+ + {% if not hide_captions %} +
+
+
+ Loading action descriptions... +
+
+
+ {% endif %} + + +
+
+
+
+ + +
+
+
+ + +
+ +
+ +
+ +
+
+ 0 / 0 +
+
+ Loading... +
+
+ +
+ + + 1.0x +
+
+
+
+ + + + + + + + + + + + +{% if not hide_captions %} + +{% endif %} + +{% endblock %} diff --git a/scripts/gradio/templates/index_wooden_gradio.html b/scripts/gradio/templates/index_wooden_gradio.html new file mode 100644 index 0000000000000000000000000000000000000000..397d4c2d102ed86ccc2744453bd85fe9903055cc --- /dev/null +++ b/scripts/gradio/templates/index_wooden_gradio.html @@ -0,0 +1,1033 @@ +{% extends 'element/blank.html' %} + +{% block content_block %} + + +
+ +
+ + + {% if not hide_captions %} +
+
+
+ Loading action descriptions... +
+
+
+ {% endif %} + + +
+
+
+ +
+
+ 0 / 0 +
+
+
+ + +
+ Loading... +
+ + +
+ + + + 1.0x +
+
+ + + + + + + + + + + + +{% if not hide_captions %} + +{% endif %} + +{% endblock %} diff --git a/scripts/gradio/vis_gradio.py b/scripts/gradio/vis_gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..67788c2660a2cdf2b3562d6f278469c3f5a93d7a --- /dev/null +++ b/scripts/gradio/vis_gradio.py @@ -0,0 +1,103 @@ +import os +import sys +from os import path as osp + +from flask import Flask, jsonify, render_template, request + +sys.path.append(osp.dirname(osp.dirname(osp.dirname(__file__)))) +from hymotion.utils.visualize_mesh_web import ( + get_cached_captions, + get_cached_smpl_frames, + get_output_dir, + sanitize_filename, + sanitize_folder_name, + safe_path_join, +) + +template_folder = os.path.join(os.path.dirname(__file__), "templates") +static_folder = os.path.join(os.path.dirname(__file__), "static") + +app = Flask(__name__, template_folder=template_folder, static_folder=static_folder) + + +@app.route("/") +def home(): + return "HMotion Visualization Server is Running. Use /view/ to access content.", 200 + + +@app.route("/view/") +@app.route("/view/") +def index(full_path: str = ""): + hide_captions = request.args.get("hide_captions", "0") == "1" + + # security check + if ".." in full_path or full_path.startswith("/"): + return "Invalid path", 403 + if "/" in full_path: + raw_folder, raw_file = full_path.rsplit("/", 1) + else: + raw_folder, raw_file = "", full_path + + folder_name = sanitize_folder_name(raw_folder) + file_name = sanitize_filename(raw_file) + + # remove possible suffix + for suffix in [".npz", ".json", ".h5"]: + if file_name.endswith(suffix): + file_name = file_name[: -len(suffix)] + break + + base_dir = get_output_dir(folder_name) + target_meta = safe_path_join(base_dir, f"{file_name}_meta.json") + target_npz = safe_path_join(base_dir, f"{file_name}.npz") + + # check if meta file exists + if os.path.isfile(target_meta) or os.path.isfile(target_npz): + return render_template( + "index_wooden_gradio.html", + folder_name=folder_name, + file_name=file_name, + mirror_name=None, + hide_captions=hide_captions, + next_file=None, + ) + else: + return ( + render_template( + "error_file_not_found.html", + folder_name=folder_name, + file_name=file_name, + full_path=target_meta, + file_type="meta.json", + original_folder=folder_name, + original_file=full_path, + ), + 404, + ) + + +@app.route("/query_smpl//") +def query_smpl(folder_name: str, file_name: str): + smpl_data = get_cached_smpl_frames(folder_name, file_name) + return jsonify(smpl_data) + + +@app.route("/query_caption//") +def query_caption(folder_name: str, file_name: str): + hide_captions = request.args.get("hide_captions", "0") == "1" + if hide_captions: + return jsonify({"result": []}) + captions = get_cached_captions(folder_name, file_name) + return jsonify({"result": captions}) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=8081) + args = parser.parse_args() + + print(f">>> Starting Flask server on {args.host}:{args.port}") + app.run(host=args.host, port=args.port, debug=False, threaded=True) diff --git a/scripts/gradio/vis_routes.py b/scripts/gradio/vis_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..96490550d681cfbcc1ca5a3f23efaed20b44a9b5 --- /dev/null +++ b/scripts/gradio/vis_routes.py @@ -0,0 +1,91 @@ +import os + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.templating import Jinja2Templates + +from hymotion.utils.visualize_mesh_web import ( + get_cached_captions, + get_cached_smpl_frames, + get_output_dir, + safe_path_join, + sanitize_filename, + sanitize_folder_name, +) + +# Define Router +router = APIRouter() + +# Define template directory +current_dir = os.path.dirname(os.path.abspath(__file__)) +templates = Jinja2Templates(directory=os.path.join(current_dir, "templates")) + + +@router.get("/wait_for_data") +async def wait_for_data(): + return {"status": "ready", "frames": 0} + + +@router.get("/view/{full_path:path}", response_class=HTMLResponse) +async def view_visualization(request: Request, full_path: str): + hide_captions = request.query_params.get("hide_captions", "0") == "1" + + # Security check and path parsing logic + if ".." in full_path or full_path.startswith("/"): + raise HTTPException(status_code=403, detail="Invalid path") + + if "/" in full_path: + raw_folder, raw_file = full_path.rsplit("/", 1) + else: + raw_folder, raw_file = "", full_path + + folder_name = sanitize_folder_name(raw_folder) + file_name = sanitize_filename(raw_file) + + for suffix in [".npz", ".json", ".h5"]: + if file_name.endswith(suffix): + file_name = file_name[: -len(suffix)] + break + + base_dir = get_output_dir(folder_name) + target_meta = safe_path_join(base_dir, f"{file_name}_meta.json") + target_npz = safe_path_join(base_dir, f"{file_name}.npz") + + if os.path.isfile(target_meta) or os.path.isfile(target_npz): + # FastAPI template rendering needs to pass in request + return templates.TemplateResponse( + "index_wooden_gradio.html", + { + "request": request, + "folder_name": folder_name, + "file_name": file_name, + "hide_captions": hide_captions, + }, + ) + else: + return templates.TemplateResponse( + "error_file_not_found.html", + { + "request": request, + "folder_name": folder_name, + "file_name": file_name, + "full_path": target_meta, + "file_type": "meta.json", + }, + status_code=404, + ) + + +@router.get("/query_smpl/{folder_name:path}/{file_name}") +async def query_smpl(folder_name: str, file_name: str): + smpl_data = get_cached_smpl_frames(folder_name, file_name) + return JSONResponse(content=smpl_data) + + +@router.get("/query_caption/{folder_name:path}/{file_name}") +async def query_caption(request: Request, folder_name: str, file_name: str): + hide_captions = request.query_params.get("hide_captions", "0") == "1" + if hide_captions: + return JSONResponse({"result": []}) + captions = get_cached_captions(folder_name, file_name) + return JSONResponse({"result": captions}) diff --git a/stats/Mean.npy b/stats/Mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..f6f59ef8e6e2e553a961c0684323040216c11090 --- /dev/null +++ b/stats/Mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4bbe5f68881fe793d509f262ad3dc698a91f79158a9bde9d6f24d46ef9652bd +size 1736 diff --git a/stats/Std.npy b/stats/Std.npy new file mode 100644 index 0000000000000000000000000000000000000000..c90b61ed66b47cffac083ee7248bbdb9b668955a --- /dev/null +++ b/stats/Std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cc86d1ca4b47fbd0e84b677c4397e51b6615abcb34e38fdd9daabbe6a5685cc +size 1736