chingshuai committed
Commit fa80dfd · Parent(s): c4d5f5a
merge gradio_app, runtime

Files changed:
- gradio_app.py +320 -133
- hymotion/network/text_encoders/text_encoder.py +5 -3
- hymotion/pipeline/motion_diffusion.py +6 -13
- hymotion/prompt_engineering/client.py +88 -0
- hymotion/utils/gradio_css.py +29 -0
- hymotion/utils/gradio_runtime.py +14 -17
- hymotion/utils/t2m_runtime.py +2 -2
- requirements.txt +2 -2
gradio_app.py
CHANGED

@@ -11,7 +11,7 @@ from typing import List, Optional, Tuple, Union
 import gradio as gr
 from hymotion.utils.gradio_runtime import ModelInference
 from hymotion.utils.gradio_utils import try_to_download_model, try_to_download_text_encoder
-from hymotion.utils.gradio_css import get_placeholder_html, APP_CSS, HEADER_BASE_MD, FOOTER_MD
+from hymotion.utils.gradio_css import get_placeholder_html, APP_CSS, HEADER_BASE_MD, FOOTER_MD, WITHOUT_PROMPT_ENGINEERING_WARNING
 # Import spaces for Hugging Face Zero GPU support
 import spaces

@@ -20,6 +20,155 @@ DATA_SOURCES = {
     "example_prompts": "examples/example_prompts/example_subset.json",
 }
 
+# Pre-generated examples for gallery display (generated on first startup)
+# Add/remove items to control the number of examples
+EXAMPLE_GALLERY_LIST = [
+    {
+        "prompt": "A person jumps upward with both legs twice.",
+        "duration": 4.5,
+        "seeds": "792",
+        "cfg_scale": 5.0,
+        "filename": "jump_twice",
+    },
+    # Add more examples here as needed:
+    {
+        "prompt": "A person jumps on their right leg.",
+        "duration": 4.5,
+        "seeds": "941",
+        "cfg_scale": 5.0,
+        "filename": "jump_right_leg",
+    },
+]
+EXAMPLE_GALLERY_OUTPUT_DIR = "examples/pregenerated"
+
+
+def ensure_examples_generated(model_inference_obj) -> List[str]:
+    """
+    Ensure all example motions are generated on first startup.
+    Returns a list of successfully generated example filenames.
+    """
+    example_dir = EXAMPLE_GALLERY_OUTPUT_DIR
+    os.makedirs(example_dir, exist_ok=True)
+
+    generated_examples = []
+
+    for example in EXAMPLE_GALLERY_LIST:
+        example_filename = example["filename"]
+        meta_path = os.path.join(example_dir, f"{example_filename}_meta.json")
+
+        # Check if already generated
+        if os.path.exists(meta_path):
+            print(f">>> Example already exists: {meta_path}")
+            generated_examples.append(example_filename)
+            continue
+
+        # Generate the example
+        print(f">>> Generating example motion: {example['prompt']}")
+        try:
+            html_content, fbx_files = model_inference_obj.run_inference(
+                text=example["prompt"],
+                seeds_csv=example["seeds"],
+                motion_duration=example["duration"],
+                cfg_scale=example["cfg_scale"],
+                output_format="dict",  # Don't generate FBX for examples
+                original_text=example["prompt"],
+                output_dir=example_dir,
+                output_filename=example_filename,
+            )
+            print(f">>> Example '{example_filename}' generated successfully!")
+            generated_examples.append(example_filename)
+        except Exception as e:
+            print(f">>> Failed to generate example '{example_filename}': {e}")
+
+    return generated_examples
+
+
+def load_example_gallery_html(example_index: int = 0) -> str:
+    """
+    Load a specific pre-generated example and return iframe HTML for display.
+    Args:
+        example_index: Index of the example in EXAMPLE_GALLERY_LIST
+    """
+    from hymotion.utils.visualize_mesh_web import generate_static_html_content
+
+    if example_index < 0 or example_index >= len(EXAMPLE_GALLERY_LIST):
+        return ""
+
+    example = EXAMPLE_GALLERY_LIST[example_index]
+    example_dir = EXAMPLE_GALLERY_OUTPUT_DIR
+    example_filename = example["filename"]
+    meta_path = os.path.join(example_dir, f"{example_filename}_meta.json")
+
+    if not os.path.exists(meta_path):
+        return f"""
+        <div style='height: 300px; display: flex; justify-content: center; align-items: center;
+                    background: #2d3748; border-radius: 12px; color: #a0aec0;'>
+            <p>Example not generated yet. Please restart the app.</p>
+        </div>
+        """
+
+    try:
+        html_content = generate_static_html_content(
+            folder_name=example_dir,
+            file_name=example_filename,
+            hide_captions=False,
+        )
+        escaped_html = html_content.replace('"', "&quot;")
+        iframe_html = f"""
+        <iframe
+            srcdoc="{escaped_html}"
+            width="100%"
+            height="350px"
+            style="border: none; border-radius: 12px; box-shadow: 0 4px 20px rgba(0,0,0,0.1);"
+        ></iframe>
+        """
+        return iframe_html
+    except Exception as e:
+        print(f">>> Failed to load example gallery: {e}")
+        return ""
+
+
+def get_example_gallery_grid_html() -> str:
+    """
+    Generate a grid layout HTML for all examples in the gallery.
+    """
+    if not EXAMPLE_GALLERY_LIST:
+        return "<p>No examples configured.</p>"
+
+    # Calculate grid columns based on number of examples
+    num_examples = len(EXAMPLE_GALLERY_LIST)
+    if num_examples == 1:
+        columns = 1
+    elif num_examples == 2:
+        columns = 2
+    elif num_examples <= 4:
+        columns = 2
+    else:
+        columns = 3
+
+    grid_items = []
+    for idx, example in enumerate(EXAMPLE_GALLERY_LIST):
+        iframe_html = load_example_gallery_html(idx)
+        prompt_short = example["prompt"][:60] + "..." if len(example["prompt"]) > 60 else example["prompt"]
+
+        grid_items.append(f"""
+        <div class="example-grid-item" style="background: var(--card-bg, #fff); border-radius: 12px;
+                    padding: 12px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
+            <div style="font-size: 14px; font-weight: 600; color: var(--text-primary, #333);
+                        margin-bottom: 8px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap;">
+                {prompt_short}
+            </div>
+            {iframe_html}
+        </div>
+        """)
+
+    grid_html = f"""
+    <div style="display: grid; grid-template-columns: repeat({columns}, 1fr); gap: 16px; padding: 8px;">
+        {"".join(grid_items)}
+    </div>
+    """
+    return grid_html
+
+
 def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12):
     """Load examples from txt file."""

@@ -69,19 +218,19 @@ def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12)
 
     return examples
 
+
 @spaces.GPU(duration=120)  # Request GPU for up to 120 seconds per inference
 def generate_motion_func(
     # text input
     original_text: str,
     rewritten_text: str,
-    use_prompt_engineering: bool,
     # model input
     seed_input: str,
     motion_duration: float,
     cfg_scale: float,
-    # output
-    output_dir: str,
 ) -> Tuple[str, List[str]]:
+    use_prompt_engineering = USE_PROMPT_ENGINEERING
+    output_dir = "output/gradio"
     # When rewrite is not available, use original_text directly
     if use_prompt_engineering:
         text_to_use = rewritten_text.strip()

@@ -106,7 +255,7 @@ def generate_motion_func(
         cfg_scale=cfg_scale,
         output_format=req_format,
         original_text=original_text,
-        output_dir=output_dir
+        output_dir=output_dir,
     )
     print(f"Running inference...after gpu_inference_wrapper")
     # Escape HTML content for srcdoc attribute

@@ -128,12 +277,25 @@ def generate_motion_func(
         [],
     )
 
+
 class T2MGradioUI:
     def __init__(self, args):
         self.output_dir = args.output_dir
         print(f"[{self.__class__.__name__}] output_dir: {self.output_dir}")
         # self.args = args
         self.prompt_engineering_available = args.use_prompt_engineering
+        if self.prompt_engineering_available:
+            try:
+                from hymotion.prompt_engineering.client import PromptEngineeringClient
+                self.prompt_engineering_client = PromptEngineeringClient()
+                # Test the client with a simple prompt to verify it works
+                self.prompt_engineering_client.rewrite_prompt_and_infer_time("A person walks forward.", max_timeout=30)
+                print(f"[{self.__class__.__name__}] Prompt engineering client initialized successfully.")
+            except Exception as e:
+                print(f"[{self.__class__.__name__}] Prompt engineering client initialization failed: {e}")
+                self.prompt_engineering_available = False
+
+
         self.all_example_data = {}
         self._init_example_data()

@@ -162,34 +324,29 @@ class T2MGradioUI:
         seeds = [random.randint(0, 999) for _ in range(4)]
         return ",".join(map(str, seeds))
 
-    def _prompt_engineering(
-        self, text: str, duration: float, enable_rewrite: bool = True, enable_duration_est: bool = True
-    ):
+    def _prompt_engineering(self, text: str, duration: float):
         if not text.strip():
-            return "", gr.update(interactive=False), gr.update()
+            return "", gr.update(interactive=False), gr.update(), "⚠️ Please enter text first"
 
-                f"❌ Text rewriting/duration prediction failed: {str(e)}",
-                gr.update(interactive=False),
-                gr.update(),
-            )
-        if not enable_rewrite:
-            rewritten_text = text
-        if not enable_duration_est:
-            predicted_duration = duration
-        return
+        print(f"\t>>> Using LLM to estimate duration/rewrite text...")
+        try:
+            predicted_duration, rewritten_text = self.prompt_engineering_client.rewrite_prompt_and_infer_time(text=text)
+        except Exception as e:
+            print(f"\t>>> Text rewriting/duration prediction failed: {e}")
+            # On failure, use original text and enable generate button
+            return (
+                text,  # Use original text as fallback
+                gr.update(interactive=True),  # Enable generate button
+                gr.update(),
+                f"⚠️ Text rewriting failed: {str(e)}\n💡 Using your original input directly. You can click [🚀 Generate Motion] to continue.",
+            )
 
+        return (
+            rewritten_text,
+            gr.update(interactive=True),
+            gr.update(value=predicted_duration),
+            "✅ Text rewriting completed! Please check and edit the rewritten text, then click [🚀 Generate Motion]",
+        )
 
     def _get_example_choices(self):
         """Get all example choices from all data sources"""

@@ -204,7 +361,10 @@ class T2MGradioUI:
     def _on_example_select(self, selected_example):
         """When selecting an example, the callback function"""
         if selected_example == "Custom Input":
+            if self.prompt_engineering_available:
+                return "", self._generate_random_seeds(), gr.update(), gr.update(value="", visible=False), gr.update(interactive=False), "Please enter text or select an example"
+            else:
+                return "", self._generate_random_seeds(), gr.update(), gr.update(), gr.update(), gr.update()
         else:
             # find the corresponding example from all data sources
             for source_name in self.all_example_data:

@@ -212,30 +372,45 @@ class T2MGradioUI:
                 for text, duration in example_data:
                     display_text = f"{text[:50]}..." if len(text) > 50 else text
                     if display_text == selected_example:
+                        if self.prompt_engineering_available:
+                            # Set text directly to rewritten_text and enable generate button
+                            return text, self._generate_random_seeds(), gr.update(value=duration), gr.update(value=text, visible=True), gr.update(interactive=True), "✅ Example selected! Click [🚀 Generate Motion] to start."
+                        else:
+                            return text, self._generate_random_seeds(), gr.update(value=duration), gr.update(), gr.update(), gr.update()
+            if self.prompt_engineering_available:
+                return "", self._generate_random_seeds(), gr.update(), gr.update(value="", visible=False), gr.update(interactive=False), "Please enter text or select an example"
+            else:
+                return "", self._generate_random_seeds(), gr.update(), gr.update(), gr.update(), gr.update()
 
     def build_ui(self):
         with gr.Blocks(css=APP_CSS) as demo:
            # Create State components for non-UI values that need to be passed to event handlers
            self.use_prompt_engineering_state = gr.State(self.prompt_engineering_available)
            self.output_dir_state = gr.State(self.output_dir)
+
            self.header_md = gr.Markdown(HEADER_BASE_MD, elem_classes=["main-header"])
 
            with gr.Row():
                # Left control panel
                with gr.Column(scale=2, elem_classes=["left-panel"]):
+
                    # Input textbox
                    if self.prompt_engineering_available:
-                        input_place_holder = "Enter text to generate motion, support Chinese and English text input."
+                        input_place_holder = "Enter text to generate motion, support Chinese and English text input. Non-humanoid Characters, Multi-person Interactions and Environment & Camera are not supported. Click [ 📚 Example Prompts ] to see more examples."
                    else:
-                        input_place_holder = "Enter text to generate motion, please use `A person ...` format to describe the motion"
+                        input_place_holder = "Enter English text to generate motion, please use `A person ...` format to describe the motion, better less than 50 words. Non-humanoid Characters, Multi-person Interactions and Environment & Camera are not supported. Click [ 📚 Example Prompts ] to see more examples."
 
                    self.text_input = gr.Textbox(
                        label="📝 Input Text",
                        placeholder=input_place_holder,
+                        lines=3,
+                        max_lines=10,
+                        autoscroll=False,
                    )
+                    # if not self.prompt_engineering_available:
+                    #     gr.Markdown(
+                    #         "Click [📚 Example Prompts] to see more examples."
+                    #     )
                    # Rewritten textbox
                    self.rewritten_text = gr.Textbox(
                        label="✏️ Rewritten Text",

@@ -281,18 +456,13 @@ class T2MGradioUI:
                        interactive=not self.prompt_engineering_available,  # Enable directly if rewrite not available
                    )
 
-                    if not self.prompt_engineering_available:
-                        gr.Markdown(
-                            "> ⚠️ **Prompt engineering is not available.** Text rewriting and duration estimation are disabled. Your input text and duration will be used directly."
-                        )
-
                    # Example selection dropdown
                    self.example_dropdown = gr.Dropdown(
                        choices=self._get_example_choices(),
                        value="Custom Input",
-                        label="📚
-                        info="Select a preset example or input your own text above",
+                        label="📚 Example Prompts",
+                        # info="Select a preset example or input your own text above",
                        interactive=True,
                    )

@@ -309,6 +479,9 @@ class T2MGradioUI:
                    self.status_output = gr.Textbox(
                        label="📊 Status Information",
                        value=status_msg,
+                        lines=1,
+                        max_lines=10,
+                        elem_classes=["status-textbox"],
                    )
 
                    # FBX Download section

@@ -325,11 +498,27 @@ class T2MGradioUI:
                # Right display area
                with gr.Column(scale=3):
                    self.output_display = gr.HTML(
-                        value=get_placeholder_html(),
-                        show_label=False,
-                        elem_classes=["flask-display"]
+                        value=get_placeholder_html(), show_label=False, elem_classes=["flask-display"]
                    )
 
+                    # Example Gallery Section
+                    with gr.Accordion("🎬 Example Gallery", open=True):
+                        self.example_gallery_display = gr.HTML(
+                            value=get_example_gallery_grid_html(),
+                            show_label=False,
+                            elem_classes=["example-gallery-display"]
+                        )
+                        # Create use example buttons for each example
+                        with gr.Row():
+                            self.use_example_btns = []
+                            for idx, example in enumerate(EXAMPLE_GALLERY_LIST):
+                                btn = gr.Button(
+                                    f"📋 Use Example {idx + 1}",
+                                    variant="secondary",
+                                    size="sm",
+                                )
+                                self.use_example_btns.append((btn, idx))
+
            # Footer
            gr.Markdown(FOOTER_MD, elem_classes=["footer"])

@@ -338,79 +527,73 @@ class T2MGradioUI:
        return demo
 
    def _build_advanced_settings(self):
-                label="Enable Text Rewriting",
-                value=True,
-                info="Automatically optimize text prompt to get better motion generation",
-            )
-
-            with gr.Group():
-                gr.Markdown("### ⏱️ Duration Settings")
-                self.enable_duration_est = gr.Checkbox(
-                    label="Enable Duration Estimation",
-                    value=True,
-                    info="Automatically estimate the duration of the motion",
-                )
-        else:
-            # Create hidden placeholders with default values (disabled)
-            self.enable_rewrite = gr.Checkbox(
-                label="Enable Text Rewriting",
-                value=False,
-                visible=False,
-            )
-            self.
-            placeholder="Enter comma separated seed list (e.g.: 0,1,2,3)",
-            info="Random seeds control the diversity of generated motions",
-        )
-        with gr.Column(scale=1, min_width=60, elem_classes=["dice-container"]):
-            self.dice_btn = gr.Button(
-                "🎲 Lucky Button",
-                variant="secondary",
-                size="sm",
-                elem_classes=["dice-button"],
-            )
+        with gr.Row():
+            self.seed_input = gr.Textbox(
+                label="🎯 Random Seeds",
+                value="0,1,2,3",
+                placeholder="e.g.: 0,1,2,3",
+                scale=3,
+            )
+            self.dice_btn = gr.Button(
+                "🎲",
+                variant="secondary",
+                size="sm",
+                scale=1,
+                min_width=50,
+            )
+        self.cfg_slider = gr.Slider(
+            minimum=1,
+            maximum=10,
+            value=5.0,
+            step=0.1,
+            label="⚙️ CFG Strength",
+        )
 
+    def _on_use_example(self, example_idx: int):
+        """When clicking 'Use This Example' button, fill in the example prompt"""
+        if example_idx < 0 or example_idx >= len(EXAMPLE_GALLERY_LIST):
+            if self.prompt_engineering_available:
+                return ("", "0,1,2,3", gr.update(), gr.update(value="", visible=False), gr.update(interactive=False), "Please select a valid example")
+            else:
+                return ("", "0,1,2,3", gr.update(), gr.update(), gr.update(), gr.update())
+
+        example = EXAMPLE_GALLERY_LIST[example_idx]
+        if self.prompt_engineering_available:
+            # Set text directly to rewritten_text and enable generate button
+            return (
+                example["prompt"],
+                example["seeds"],
+                gr.update(value=example["duration"]),
+                gr.update(value=example["prompt"], visible=True),
+                gr.update(interactive=True),
+                "✅ Example selected! Click [🚀 Generate Motion] to start.",
+            )
+        else:
+            return (
+                example["prompt"],
+                example["seeds"],
+                gr.update(value=example["duration"]),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
 
    def _bind_events(self):
        # Generate random seeds
        self.dice_btn.click(self._generate_random_seeds, outputs=[self.seed_input])
 
+        # Use example buttons - bind each button to its example
+        for btn, idx in self.use_example_btns:
+            btn.click(
+                fn=lambda i=idx: self._on_use_example(i),
+                outputs=[self.text_input, self.seed_input, self.duration_slider, self.rewritten_text, self.generate_btn, self.status_output],
+            )
+
        # Bind example selection event
        self.example_dropdown.change(
            fn=self._on_example_select,
            inputs=[self.example_dropdown],
-            outputs=[self.text_input, self.seed_input, self.duration_slider],
+            outputs=[self.text_input, self.seed_input, self.duration_slider, self.rewritten_text, self.generate_btn, self.status_output],
        )
 
        # Rewrite text logic (only bind when rewrite is available)

@@ -420,16 +603,11 @@ class T2MGradioUI:
            inputs=[
                self.text_input,
                self.duration_slider,
-                self.enable_rewrite,
-                self.enable_duration_est,
            ],
-            outputs=[self.rewritten_text, self.generate_btn, self.duration_slider],
+            outputs=[self.rewritten_text, self.generate_btn, self.duration_slider, self.status_output],
        ).then(
-            fn=lambda: (
-                "Text rewriting completed! Please check and edit the rewritten text, then click [🚀 Generate Motion]",
-            ),
-            outputs=[self.rewritten_text, self.status_output],
+            fn=lambda: gr.update(visible=True),
+            outputs=[self.rewritten_text],
        )
 
        # Generate motion logic

@@ -438,16 +616,8 @@ class T2MGradioUI:
            outputs=[self.status_output],
        ).then(
            generate_motion_func,
-            inputs=[
-                self.text_input,
-                self.rewritten_text,
-                self.use_prompt_engineering_state,
-                self.seed_input,
-                self.duration_slider,
-                self.cfg_slider,
-                self.output_dir_state,
-            ],
-            outputs=[self.output_display, self.fbx_files]
+            inputs=[self.text_input, self.rewritten_text, self.seed_input, self.duration_slider, self.cfg_slider],
+            outputs=[self.output_display, self.fbx_files],
        ).then(
            fn=lambda fbx_list: (
                (

@@ -463,12 +633,22 @@ class T2MGradioUI:
 
        # Reset logic - different behavior based on rewrite availability
        if self.prompt_engineering_available:
+            # When text_input changes:
+            # - If text_input == rewritten_text, it means the change was triggered by example selection,
+            #   so we should NOT hide the rewritten_text (keep it visible and generate button enabled)
+            # - If text_input != rewritten_text, it means user manually edited the input,
+            #   so we should hide the rewritten_text and require a new rewrite
            self.text_input.change(
-                fn=lambda: (
-                    gr.update(visible=False),
-                    gr.update(interactive=False),
+                fn=lambda text, rewritten: (
+                    gr.update() if text.strip() == rewritten.strip() else gr.update(visible=False),
+                    gr.update() if text.strip() == rewritten.strip() else gr.update(interactive=False),
+                    (
+                        "✅ Example selected! Click [🚀 Generate Motion] to start."
+                        if text.strip() == rewritten.strip() and text.strip()
+                        else "Please click the [🔄 Rewrite Text] button to rewrite the text first"
+                    ),
                ),
+                inputs=[self.text_input, self.rewritten_text],
                outputs=[self.rewritten_text, self.generate_btn, self.status_output],
            )
        else:

@@ -508,11 +688,8 @@ def create_demo(final_model_path):
    class Args:
        model_path = final_model_path
        output_dir = "output/gradio"
-        use_prompt_engineering =
+        use_prompt_engineering = USE_PROMPT_ENGINEERING
        use_text_encoder = True
-        prompt_engineering_host = os.environ.get("PROMPT_HOST", None)
-        prompt_engineering_model_path = os.environ.get("PROMPT_MODEL_PATH", None)
-        disable_prompt_engineering = os.environ.get("DISABLE_PROMPT_ENGINEERING", False)
 
    args = Args()

@@ -538,11 +715,21 @@ def create_demo(final_model_path):
 
 if __name__ == "__main__":
    # Create demo at module level for Hugging Face Spaces
+    import argparse
+    parser = argparse.ArgumentParser(description="HY-Motion-1.0 Gradio App")
+    parser.add_argument("--port", type=int, default=7860, help="Port to listen on")
+    args = parser.parse_args()
+
+    USE_PROMPT_ENGINEERING = True
    try_to_download_text_encoder()
    # Then download the main model
    final_model_path = try_to_download_model()
    model_inference = ModelInference(final_model_path,
                                     use_prompt_engineering=False, use_text_encoder=True)
    model_inference.initialize_model(device="cpu")
+
+    # Generate examples on first startup (if not exists)
+    ensure_examples_generated(model_inference)
+
    demo = create_demo(final_model_path)
-    demo.launch(server_name="0.0.0.0")
+    demo.launch(server_name="0.0.0.0", server_port=args.port)
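A note on the per-button bindings above: `fn=lambda i=idx: self._on_use_example(i)` binds the loop index as a default argument. A minimal sketch of why that default matters (plain Python, independent of Gradio):

    # Lambdas close over variables, not values: without the default argument,
    # every callback created in the loop would see the loop variable's final value.
    late = [lambda: i for i in range(3)]
    print([f() for f in late])    # prints [2, 2, 2]

    # Binding i as a default argument captures its value at definition time,
    # so each "Use Example" button keeps its own index.
    early = [lambda i=i: i for i in range(3)]
    print([f() for f in early])   # prints [0, 1, 2]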
hymotion/network/text_encoders/text_encoder.py
CHANGED

@@ -99,7 +99,9 @@ class HYTextModel(nn.Module):
             padding_side="right",
         )
         self.llm_text_encoder = LLM_ENCODER_LAYOUT[llm_type]["text_encoder_class"].from_pretrained(
-            LLM_ENCODER_LAYOUT[llm_type]["module_path"],
+            LLM_ENCODER_LAYOUT[llm_type]["module_path"],
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.bfloat16,
         )
         self.llm_text_encoder = self.llm_text_encoder.eval().requires_grad_(False)
         self.ctxt_dim = self.llm_text_encoder.config.hidden_size

@@ -150,9 +152,9 @@ class HYTextModel(nn.Module):
             )
         )
         if self.llm_type == "qwen3":
-            ctxt_raw = llm_outputs.hidden_states[-1]
+            ctxt_raw = llm_outputs.hidden_states[-1].clone()
         else:
-            ctxt_raw = llm_outputs.last_hidden_state
+            ctxt_raw = llm_outputs.last_hidden_state.clone()
 
         start = self.crop_start
         end = start + self._orig_max_length_llm
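The two new `from_pretrained` arguments reduce host memory while the encoder loads, and the added `.clone()` calls copy the selected hidden state out of the LLM output object, presumably so the full per-layer `hidden_states` tuple can be garbage-collected. A rough sketch of the loading pattern; the model id below is a placeholder, not the repo's actual `LLM_ENCODER_LAYOUT` entry:

    import torch
    from transformers import AutoModel

    # low_cpu_mem_usage streams weights in without first materializing a full
    # fp32 copy, and torch_dtype keeps the parameters in bfloat16, roughly
    # halving resident memory versus the fp32 default.
    encoder = AutoModel.from_pretrained(
        "Qwen/Qwen3-0.6B",  # placeholder id; the real path comes from LLM_ENCODER_LAYOUT
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    encoder = encoder.eval().requires_grad_(False)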
hymotion/pipeline/motion_diffusion.py
CHANGED

@@ -176,7 +176,6 @@ class MotionGeneration(torch.nn.Module):
     def load_in_demo(
         self,
         ckpt_name: str,
-        mean_std_name: Optional[str] = None,
         build_text_encoder: bool = True,
         allow_empty_ckpt: bool = False,
     ) -> None:

@@ -188,11 +187,6 @@ class MotionGeneration(torch.nn.Module):
         else:
             checkpoint = torch.load(ckpt_name, map_location="cpu", weights_only=False)
             self.load_state_dict(checkpoint["model_state_dict"], strict=False)
-        if mean_std_name is not None:
-            assert os.path.exists(mean_std_name), f"{mean_std_name} not found"
-            if not os.path.isfile(mean_std_name):
-                mean_std_name = None
-        self._load_mean_std(mean_std_name)
         self.motion_transformer.eval()
         if build_text_encoder and not self.uncondition_mode:
             self.text_encoder = load_object(self._text_encoder_module, self._text_encoder_cfg)

@@ -299,11 +293,11 @@ class MotionGeneration(torch.nn.Module):
         k3d = torch.zeros(B, L, nj, 3, device=device)
 
         return dict(
-            latent_denorm=latent_denorm,  # (B, L, 201)
-            keypoints3d=k3d,  # (B, L, J, 3)
-            rot6d=rot6d_smooth,  # (B, L, J, 6)
-            transl=transl_smooth,  # (B, L, 3)
-            root_rotations_mat=root_rotmat_smooth,  # (B, L, 3, 3)
+            latent_denorm=latent_denorm.cpu().detach(),  # (B, L, 201)
+            keypoints3d=k3d.cpu().detach(),  # (B, L, J, 3)
+            rot6d=rot6d_smooth.cpu().detach(),  # (B, L, J, 6)
+            transl=transl_smooth.cpu().detach(),  # (B, L, 3)
+            root_rotations_mat=root_rotmat_smooth.cpu().detach(),  # (B, L, 3, 3)
         )
 
     @staticmethod

@@ -584,9 +578,8 @@ class MotionFlowMatching(MotionGeneration):
         )
         with torch.no_grad():
             trajectory = odeint(fn, y0, t, **self._noise_scheduler_cfg)
-        sampled = trajectory[-1]
+        sampled = trajectory[-1][:, :length, ...].clone()
         assert isinstance(sampled, Tensor), f"sampled must be a Tensor, but got {type(sampled)}"
-        sampled = sampled[:, :length, ...].clone()
 
         output_dict = self.decode_motion_from_latent(sampled, should_apply_smooothing=True)
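Returning `.cpu().detach()` copies of the decoded tensors means the motion dict no longer pins GPU memory once decoding finishes, which pairs with the `empty_cache` change in t2m_runtime.py below. A minimal sketch of the idea; the helper name is illustrative, not from the repo:

    import torch

    def offload(t: torch.Tensor) -> torch.Tensor:
        # Drop the autograd graph and copy to host memory so the allocator can
        # reclaim the CUDA buffer backing the intermediate result.
        return t.cpu().detach()

(Under the `torch.no_grad()` context these calls run in, `.cpu().detach()` and the more common `.detach().cpu()` ordering behave the same.)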
hymotion/prompt_engineering/client.py
ADDED

@@ -0,0 +1,88 @@
+import os
+import json
+import time
+from openai import OpenAI
+import json
+
+PROMPT = """
+# Role
+You are an expert in 3D motion analysis, animation timing, and choreography. Your task is to analyze textual action descriptions to estimate execution time and standardize the language for motion generation systems.
+
+# Task
+Analyze the user-provided [Input Action] and generate a structured JSON response containing a duration estimate and a refined caption.
+
+# Instructions
+
+### 1. Duration Estimation (frame_count)
+- Analyze the complexity, speed, and physical constraints of the described action.
+- Estimate the time required to perform the action in a **smooth, natural, and realistic manner**.
+- Calculate the total duration in frames based on a **30 fps** (frames per second) standard.
+- Output strictly as an Integer.
+
+### 2. Caption Refinement (short_caption)
+- Generate a refined, grammatically correct version of the input description in **English**.
+- **Strict Constraints**:
+  - You must **PRESERVE** the original sequence of events (chronological order).
+  - You must **RETAIN** all original spatial modifiers (e.g., "left," "upward," "quickly").
+  - **DO NOT** add new sub-actions or hallucinate details not present in the input.
+  - **DO NOT** delete any specific movements.
+- The goal is to improve clarity and flow while maintaining 100% semantic fidelity to the original request.
+
+### 3. Output Format
+- Return **ONLY** a raw JSON object.
+- Do not use Markdown formatting (i.e., do not use ```json ... ```).
+- Ensure the JSON is valid and parsable.
+
+# JSON Structure
+{{
+    "duration": <Integer, frames at 30fps>,
+    "short_caption": "<String, the refined English description>"
+}}
+
+# Input
+{}
+"""
+
+
+class PromptEngineeringClient:
+    def __init__(self):
+        BASE_URL = os.environ.get("PROMPT_ENGINEERING_BASE_URL", "http://IP:PORT/v1")
+        API_KEY = os.environ.get("PROMPT_ENGINEERING_API_KEY", "EMPTY")
+        MODEL_NAME = os.environ.get("PROMPT_ENGINEERING_MODEL_NAME", "")
+        client = OpenAI(
+            api_key=API_KEY,
+            base_url=BASE_URL
+        )
+        self.model_name = MODEL_NAME
+        self.client = client
+
+    def rewrite_prompt_and_infer_time(self, text, max_timeout=30):
+        start_time = time.time()
+        while True:
+            end_time = time.time()
+            if end_time - start_time > max_timeout:
+                raise Exception("Prompt rewriting timeout")
+            try:
+                chat_response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=[
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": PROMPT.format(text)},
+                    ]
+                )
+                chat_response = json.loads(chat_response.choices[0].message.content.strip())
+                duration = chat_response["duration"]
+                short_caption = chat_response["short_caption"]
+                pred_duration = min(12, max(1, int(duration) / 30))
+            except Exception as e:
+                print(e)
+                continue
+            else:
+                break
+
+        return pred_duration, short_caption
+
+
+if __name__ == "__main__":
+    # python -m hymotion.prompt_engineering.client
+    client = PromptEngineeringClient()
+    print(client.rewrite_prompt_and_infer_time("A person jumps upward with both legs twice."))
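The new client speaks to any OpenAI-compatible endpoint and is configured entirely through environment variables, with `http://IP:PORT/v1` as a fill-in-your-own default. A usage sketch; the URL and model name below are placeholders, e.g. for a locally served model:

    import os

    # Set these before constructing the client; the values are examples only.
    os.environ["PROMPT_ENGINEERING_BASE_URL"] = "http://localhost:8000/v1"
    os.environ["PROMPT_ENGINEERING_API_KEY"] = "EMPTY"
    os.environ["PROMPT_ENGINEERING_MODEL_NAME"] = "my-rewriter-model"

    from hymotion.prompt_engineering.client import PromptEngineeringClient

    client = PromptEngineeringClient()
    duration_s, caption = client.rewrite_prompt_and_infer_time(
        "A person waves with the left hand."
    )
    # duration_s is the predicted length in seconds, clamped to [1, 12]
    # (the LLM's frame count divided by 30 fps); caption is the rewritten prompt.

One caveat: the retry loop only checks `max_timeout` at the top of each iteration, so a single slow completion request can overrun the nominal timeout before the exception is raised.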
hymotion/utils/gradio_css.py
CHANGED

@@ -116,6 +116,14 @@ APP_CSS = """
     font-weight:500 !important;
 }
 
+/* Status textbox - dynamic height based on content */
+.status-textbox textarea{
+    height:auto !important;
+    min-height:2.5em !important;
+    resize:none !important;
+    overflow-y:hidden !important;
+}
+
 /* Button base class and variant */
 .generate-button,.rewrite-button,.dice-button{
     border:none !important; color:#fff !important; font-weight:600 !important;

@@ -206,6 +214,20 @@ APP_CSS = """
     padding:10px !important;
     color:var(--text-secondary, #666) !important;
 }
+
+/* Example Gallery Styles */
+.example-gallery-display{
+    padding:0 !important; margin:12px 0 !important; border:none !important;
+    box-shadow:none !important; background:var(--iframe-bg) !important;
+    border-radius:10px !important; position:relative !important;
+    min-height:500px !important;
+}
+
+.example-gallery-display iframe{
+    width:100% !important; min-height:500px !important;
+    border:none !important; border-radius:10px !important; display:block !important;
+    background:var(--iframe-bg) !important;
+}
 """
 
 HEADER_BASE_MD = "# HY-Motion-1.0: Text-to-Motion Playground\n### *Tencent Hunyuan 3D Digital Human Team*"

@@ -248,3 +270,10 @@ def get_placeholder_html() -> str:
     </div>
     """
 
+
+WITHOUT_PROMPT_ENGINEERING_WARNING = """
+<div style='color: #ff0000; font-weight: bold;'>
+<p>Prompt engineering is not available. You should use `A person ...` format to describe the motion and manually adjust the duration. Click [📚 Example Prompts] to see more examples.</p>
+<p>Non-humanoid Characters, Multi-person Interactions and Environment & Camera are not supported.</p>
+</div>
+"""
hymotion/utils/gradio_runtime.py
CHANGED

@@ -26,6 +26,7 @@ def _now():
     ms = int((t - int(t)) * 1000)
     return time.strftime("%Y%m%d_%H%M%S", time.localtime(t)) + f"{ms:03d}"
 
+
 _MODEL_CACHE = None
 

@@ -37,19 +38,14 @@ class SimpleRuntime(torch.nn.Module):
         # prompt engineering
         if self.load_prompt_engineering:
             print(f"[{self.__class__.__name__}] Loading prompt engineering...")
-            self.prompt_rewriter = PromptRewriter(
-                host=None, model_path=None, device="cpu"
-            )
+            self.prompt_rewriter = PromptRewriter(host=None, model_path=None, device="cpu")
         else:
             self.prompt_rewriter = None
         # text encoder
         if self.load_text_encoder:
             print(f"[{self.__class__.__name__}] Loading text encoder...")
             _text_encoder_module = "hymotion/network/text_encoders/text_encoder.HYTextModel"
-            _text_encoder_cfg = {
-                "llm_type": "qwen3",
-                "max_length_llm": 128
-            }
+            _text_encoder_cfg = {"llm_type": "qwen3", "max_length_llm": 128}
             text_encoder = load_object(_text_encoder_module, _text_encoder_cfg)
         else:
             text_encoder = None

@@ -66,7 +62,6 @@ class SimpleRuntime(torch.nn.Module):
         print(f"[{self.__class__.__name__}] Loading ckpt: {ckpt_name}")
         pipeline.load_in_demo(
             os.path.join(os.path.dirname(config_path), ckpt_name),
-            "stats",
             build_text_encoder=False,
             allow_empty_ckpt=False,
         )

@@ -87,7 +82,6 @@ class SimpleRuntime(torch.nn.Module):
         self.fbx_converter = None
         print(">>> FBX module not found. FBX export will be disabled.")
 
-
     def _generate_html_content(
         self,
         timestamp: str,

@@ -128,7 +122,6 @@ class SimpleRuntime(torch.nn.Module):
         # Return error HTML
         return f"<html><body><h1>Error generating visualization</h1><p>{str(e)}</p></body></html>"
 
-
     def _generate_fbx_files(
         self,
         visualization_data: dict,

@@ -247,6 +240,7 @@ class SimpleRuntime(torch.nn.Module):
         else:
             raise ValueError(f">>> Invalid output format: {output_format}")
 
+
 class ModelInference:
     """
     Handles model inference and data processing for Depth Anything 3.

@@ -288,7 +282,7 @@ class ModelInference:
             config_path=os.path.join(self.model_path, "config.yml"),
             ckpt_name="latest.ckpt",
             load_prompt_engineering=self.use_prompt_engineering,
-            load_text_encoder=self.use_text_encoder
+            load_text_encoder=self.use_text_encoder,
         )
         # Load to CPU first (faster, and allows reuse)
         _MODEL_CACHE = _MODEL_CACHE.to("cpu")

@@ -306,9 +300,7 @@ class ModelInference:
 
         return _MODEL_CACHE
 
-    def run_inference(
-        self, *args, **kwargs
-    ):
+    def run_inference(self, *args, **kwargs):
         """
         Run DepthAnything3 model inference on images.
         Args:

@@ -333,7 +325,6 @@ class ModelInference:
         # Initialize model if needed - get model instance (not stored in self)
         model = self.initialize_model(device)
 
-
         with torch.no_grad():
             print(f"[{self.__class__.__name__}] Running inference with torch.no_grad")
             html_content, fbx_files, model_output = model.generate_motion(*args, **kwargs)

@@ -347,7 +338,13 @@ class ModelInference:
 
         return html_content, fbx_files
 
+
 if __name__ == "__main__":
     # python -m hymotion.utils.gradio_runtime
-    runtime = SimpleRuntime(
+    runtime = SimpleRuntime(
+        config_path="assets/config_simplified.yml",
+        ckpt_name="latest.ckpt",
+        load_prompt_engineering=False,
+        load_text_encoder=False,
+    )
+    print(runtime.pipeline)
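`_MODEL_CACHE` holds one loaded runtime per process so successive Gradio requests reuse the weights instead of reloading them. A stripped-down sketch of the caching pattern; names are simplified, and the real `initialize_model` additionally shuttles the cached model between CPU and GPU:

    _MODEL_CACHE = None

    def get_cached_model(build):
        global _MODEL_CACHE
        if _MODEL_CACHE is None:
            _MODEL_CACHE = build()  # the expensive load happens once per process
        return _MODEL_CACHE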
hymotion/utils/t2m_runtime.py
CHANGED

@@ -128,7 +128,6 @@ class T2MRuntime:
         device = torch.device("cpu")
         pipeline.load_in_demo(
             self.ckpt_name,
-            os.path.dirname(self.ckpt_name),
             build_text_encoder=not self.skip_text,
             allow_empty_ckpt=allow_empty_ckpt,
         )

@@ -145,7 +144,6 @@ class T2MRuntime:
         )
         p.load_in_demo(
             self.ckpt_name,
-            os.path.dirname(self.ckpt_name),
             build_text_encoder=not self.skip_text,
             allow_empty_ckpt=allow_empty_ckpt,
         )

@@ -238,6 +236,8 @@ class T2MRuntime:
             raise
         finally:
             self._release_pipeline(pi)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     def load_text_encoder(self) -> None:
         """
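On the added call: `torch.cuda.empty_cache()` returns unused blocks from PyTorch's caching allocator to the driver but frees nothing that is still referenced, which is why it sits in the `finally` block right after `_release_pipeline(pi)` drops the request's last references. A minimal sketch of the same ordering; `release` stands in for the runtime's `_release_pipeline`:

    import torch

    def run_request(pipeline, job, release):
        try:
            return pipeline(job)
        finally:
            release(pipeline)              # drop per-request references first...
            if torch.cuda.is_available():
                torch.cuda.empty_cache()   # ...then hand cached blocks back to the driver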
requirements.txt
CHANGED

@@ -3,11 +3,13 @@ huggingface_hub==0.30.0
 
 torch==2.5.1
 torchvision==0.20.1
+torchdiffeq==0.2.5
 accelerate==0.30.1
 diffusers==0.26.3
 transformers==4.53.3
 einops==0.8.1
 safetensors==0.5.3
+bitsandbytes==0.49.0
 
 numpy>=1.24.0,<2.0
 scipy>=1.10.0

@@ -20,5 +22,3 @@ requests==2.32.4
 openai==1.78.1
 
 fbxsdkpy==2020.1.post2
-
-torchdiffeq==0.2.5