# HY-Motion-1.0 / app.py
# seawolf2357's picture
# Update app.py
# 7867fd1 verified
import argparse
import codecs as cs
import json
import os
import os.path as osp
import random
import re
import textwrap
import requests
from typing import List, Optional, Tuple, Union
import gradio as gr
from hymotion.utils.gradio_runtime import ModelInference
from hymotion.utils.gradio_utils import try_to_download_model, try_to_download_text_encoder
from hymotion.utils.visualize_mesh_web import generate_static_html_content
# Import spaces for Hugging Face Zero GPU support
import spaces
# ============================================
# 🎨 Fireworks AI LLM Configuration
# ============================================
# Model and endpoint used for LLM-based prompt rewriting.
FIREWORKS_MODEL = "accounts/fireworks/models/gpt-oss-120b"
FIREWORKS_API_URL = "https://api.fireworks.ai/inference/v1/chat/completions"
# The credential comes from the environment; the literal is only the fallback
# used when the variable is not set.
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY", "<API_KEY>")
# ============================================
# 🎨 Comic Classic Theme CSS
# ============================================
APP_CSS = """
/* ===== 🎨 Google Fonts Import ===== */
@import url('https://fonts.googleapis.com/css2?family=Bangers&family=Comic+Neue:wght@400;700&display=swap');
/* ===== 🎨 Comic Classic Background ===== */
.gradio-container {
background-color: #FEF9C3 !important;
background-image:
radial-gradient(#1F2937 1px, transparent 1px) !important;
background-size: 20px 20px !important;
min-height: 100vh !important;
font-family: 'Comic Neue', cursive, sans-serif !important;
}
/* ===== Hide HuggingFace Header ===== */
.huggingface-space-header,
#space-header,
.space-header,
[class*="space-header"],
.svelte-1ed2p3z,
.space-header-badge,
.header-badge,
[data-testid="space-header"],
.svelte-kqij2n,
.svelte-1ax1toq,
.embed-container > div:first-child {
display: none !important;
visibility: hidden !important;
height: 0 !important;
width: 0 !important;
overflow: hidden !important;
opacity: 0 !important;
pointer-events: none !important;
}
/* ===== Hide Footer ===== */
footer,
.footer,
.gradio-container footer,
.built-with,
[class*="footer"],
.gradio-footer,
.main-footer,
div[class*="footer"],
.show-api,
.built-with-gradio,
a[href*="gradio.app"],
a[href*="huggingface.co/spaces"] {
display: none !important;
visibility: hidden !important;
height: 0 !important;
padding: 0 !important;
margin: 0 !important;
}
/* ===== Main Container ===== */
#col-container {
max-width: 1200px;
margin: 0 auto;
}
/* ===== 🎨 Header Title - Comic Style ===== */
.main-header h1 {
font-family: 'Bangers', cursive !important;
color: #1F2937 !important;
font-size: 3.5rem !important;
font-weight: 400 !important;
text-align: center !important;
margin-bottom: 0.5rem !important;
text-shadow:
4px 4px 0px #FACC15,
6px 6px 0px #1F2937 !important;
letter-spacing: 3px !important;
-webkit-text-stroke: 2px #1F2937 !important;
}
/* ===== 🎨 Subtitle ===== */
.subtitle {
text-align: center !important;
font-family: 'Comic Neue', cursive !important;
font-size: 1.2rem !important;
color: #1F2937 !important;
margin-bottom: 1.5rem !important;
font-weight: 700 !important;
}
/* ===== 🎨 Card/Panel - Comic Frame Style ===== */
.gr-panel,
.gr-box,
.gr-form,
.block,
.gr-group,
.left-panel,
.flask-display {
background: #FFFFFF !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
box-shadow: 6px 6px 0px #1F2937 !important;
transition: all 0.2s ease !important;
}
.gr-panel:hover,
.block:hover {
transform: translate(-2px, -2px) !important;
box-shadow: 8px 8px 0px #1F2937 !important;
}
/* ===== 🎨 Input Fields (Textbox) ===== */
textarea,
input[type="text"],
input[type="number"] {
background: #FFFFFF !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
color: #1F2937 !important;
font-family: 'Comic Neue', cursive !important;
font-size: 1rem !important;
font-weight: 700 !important;
transition: all 0.2s ease !important;
}
textarea:focus,
input[type="text"]:focus,
input[type="number"]:focus {
border-color: #3B82F6 !important;
box-shadow: 4px 4px 0px #3B82F6 !important;
outline: none !important;
}
textarea::placeholder {
color: #9CA3AF !important;
font-weight: 400 !important;
}
/* ===== 🎨 Primary Button - Comic Blue ===== */
.gr-button-primary,
button.primary,
.gr-button.primary,
.generate-button {
background: #3B82F6 !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
color: #FFFFFF !important;
font-family: 'Bangers', cursive !important;
font-weight: 400 !important;
font-size: 1.3rem !important;
letter-spacing: 2px !important;
padding: 14px 28px !important;
box-shadow: 5px 5px 0px #1F2937 !important;
transition: all 0.1s ease !important;
text-shadow: 1px 1px 0px #1F2937 !important;
}
.gr-button-primary:hover,
button.primary:hover,
.gr-button.primary:hover,
.generate-button:hover {
background: #2563EB !important;
transform: translate(-2px, -2px) !important;
box-shadow: 7px 7px 0px #1F2937 !important;
}
.gr-button-primary:active,
button.primary:active,
.gr-button.primary:active,
.generate-button:active {
transform: translate(3px, 3px) !important;
box-shadow: 2px 2px 0px #1F2937 !important;
}
/* ===== 🎨 Secondary Button - Comic Red ===== */
.gr-button-secondary,
button.secondary,
.rewrite-button {
background: #EF4444 !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
color: #FFFFFF !important;
font-family: 'Bangers', cursive !important;
font-weight: 400 !important;
font-size: 1.1rem !important;
letter-spacing: 1px !important;
box-shadow: 4px 4px 0px #1F2937 !important;
transition: all 0.1s ease !important;
text-shadow: 1px 1px 0px #1F2937 !important;
}
.gr-button-secondary:hover,
button.secondary:hover,
.rewrite-button:hover {
background: #DC2626 !important;
transform: translate(-2px, -2px) !important;
box-shadow: 6px 6px 0px #1F2937 !important;
}
.gr-button-secondary:active,
button.secondary:active,
.rewrite-button:active {
transform: translate(2px, 2px) !important;
box-shadow: 2px 2px 0px #1F2937 !important;
}
/* ===== 🎨 Status/Log Output Area ===== */
.status-textbox textarea {
background: #1F2937 !important;
color: #10B981 !important;
font-family: 'Courier New', monospace !important;
font-size: 0.9rem !important;
font-weight: 400 !important;
border: 3px solid #10B981 !important;
border-radius: 8px !important;
box-shadow: 4px 4px 0px #10B981 !important;
}
/* ===== 🎨 FBX Download Section - Comic Yellow ===== */
.fbx-download-section {
background: #FEF3C7 !important;
border: 4px dashed #F59E0B !important;
border-radius: 12px !important;
padding: 20px !important;
margin-top: 15px !important;
box-shadow: 6px 6px 0px #D97706 !important;
}
.fbx-download-section .gr-file {
background: #FFFFFF !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
}
.fbx-download-title {
font-family: 'Bangers', cursive !important;
color: #D97706 !important;
font-size: 1.5rem !important;
text-align: center !important;
margin-bottom: 10px !important;
text-shadow: 2px 2px 0px #FEF3C7 !important;
}
/* ===== 🎨 Accordion - Speech Bubble Style ===== */
.gr-accordion {
background: #FACC15 !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
box-shadow: 4px 4px 0px #1F2937 !important;
}
.gr-accordion-header {
color: #1F2937 !important;
font-family: 'Comic Neue', cursive !important;
font-weight: 700 !important;
font-size: 1.1rem !important;
}
/* ===== 🎨 Image/Display Output Area ===== */
.gr-image,
.image-container,
.flask-display {
border: 4px solid #1F2937 !important;
border-radius: 8px !important;
box-shadow: 8px 8px 0px #1F2937 !important;
overflow: hidden !important;
background: #FFFFFF !important;
}
/* ===== 🎨 Labels ===== */
label,
.gr-input-label,
.gr-block-label {
color: #1F2937 !important;
font-family: 'Comic Neue', cursive !important;
font-weight: 700 !important;
font-size: 1rem !important;
}
span.gr-label {
color: #1F2937 !important;
}
/* ===== 🎨 Slider ===== */
.gr-slider input[type="range"] {
accent-color: #3B82F6 !important;
}
/* ===== 🎨 Dropdown ===== */
.gr-dropdown {
background: #FFFFFF !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
font-family: 'Comic Neue', cursive !important;
}
/* ===== 🎨 Example Gallery ===== */
.example-gallery-display {
background: #EFF6FF !important;
border: 3px solid #3B82F6 !important;
border-radius: 12px !important;
padding: 15px !important;
}
.example-grid-item {
background: #FFFFFF !important;
border: 3px solid #1F2937 !important;
border-radius: 12px !important;
box-shadow: 4px 4px 0px #1F2937 !important;
transition: all 0.2s ease !important;
}
.example-grid-item:hover {
transform: translate(-2px, -2px) !important;
box-shadow: 6px 6px 0px #1F2937 !important;
}
/* ===== 🎨 Scrollbar - Comic Style ===== */
::-webkit-scrollbar {
width: 12px;
height: 12px;
}
::-webkit-scrollbar-track {
background: #FEF9C3;
border: 2px solid #1F2937;
}
::-webkit-scrollbar-thumb {
background: #3B82F6;
border: 2px solid #1F2937;
border-radius: 0px;
}
::-webkit-scrollbar-thumb:hover {
background: #EF4444;
}
/* ===== 🎨 Selection Highlight ===== */
::selection {
background: #FACC15;
color: #1F2937;
}
/* ===== 🎨 Links ===== */
a {
color: #3B82F6 !important;
text-decoration: none !important;
font-weight: 700 !important;
}
a:hover {
color: #EF4444 !important;
}
/* ===== 🎨 Row/Column Spacing ===== */
.gr-row {
gap: 1.5rem !important;
}
.gr-column {
gap: 1rem !important;
}
/* ===== 🎨 Dice Button ===== */
.dice-btn {
background: #10B981 !important;
border: 3px solid #1F2937 !important;
border-radius: 8px !important;
color: #FFFFFF !important;
font-size: 1.5rem !important;
box-shadow: 3px 3px 0px #1F2937 !important;
transition: all 0.1s ease !important;
}
.dice-btn:hover {
background: #059669 !important;
transform: translate(-1px, -1px) !important;
box-shadow: 4px 4px 0px #1F2937 !important;
}
/* ===== Responsive Adjustments ===== */
@media (max-width: 768px) {
.main-header h1 {
font-size: 2.2rem !important;
text-shadow:
3px 3px 0px #FACC15,
4px 4px 0px #1F2937 !important;
}
.gr-button-primary,
button.primary {
padding: 12px 20px !important;
font-size: 1.1rem !important;
}
.gr-panel,
.block {
box-shadow: 4px 4px 0px #1F2937 !important;
}
}
/* ===== Disable Dark Mode ===== */
@media (prefers-color-scheme: dark) {
.gradio-container {
background-color: #FEF9C3 !important;
}
}
"""
# ============================================
# Header and Footer Markdown
# ============================================
# Page title rendered through gr.Markdown (styled by the ".main-header" CSS rule).
HEADER_MD = """
# 🎬 HY-MOTION GENERATOR 🕺
"""
# One-line tagline shown under the title.
SUBTITLE_MD = """
<p class="subtitle">✨ Transform text into realistic 3D human motion! Type your description and watch the magic happen! ✨</p>
"""
# Credits banner shown at the bottom of the page.
FOOTER_MD = """
<div style="text-align: center; margin-top: 20px; padding: 15px; background: #FACC15; border: 3px solid #1F2937; border-radius: 8px; box-shadow: 4px 4px 0px #1F2937;">
<p style="font-family: 'Comic Neue', cursive; font-weight: 700; color: #1F2937; margin: 0;">
🚀 Powered by HY-Motion | 🤗 Running on Hugging Face Spaces | 🔥 LLM by Fireworks AI
</p>
</div>
"""
# Notice for the case where LLM prompt rewriting is turned off.
WITHOUT_PROMPT_ENGINEERING_WARNING = """
<div style="background: #FEF3C7; border: 3px solid #F59E0B; border-radius: 8px; padding: 10px; margin: 10px 0;">
<p style="color: #D97706; font-weight: 700; margin: 0;">
⚠️ Prompt Engineering is disabled. Enter English text directly in "A person..." format.
</p>
</div>
"""
# define data sources
# Example-prompt files, keyed by a source name.
DATA_SOURCES = dict(
    example_prompts="examples/example_prompts/example_subset.json",
)

# Motions rendered ahead of time for the example gallery. Each entry carries
# the prompt plus the exact generation settings and an output file stem.
EXAMPLE_GALLERY_LIST = [
    dict(
        prompt="A person jumps upward with both legs twice.",
        duration=4.5,
        seeds="792",
        cfg_scale=5.0,
        filename="jump_twice",
    ),
    dict(
        prompt="A person jumps on their right leg.",
        duration=4.5,
        seeds="941",
        cfg_scale=5.0,
        filename="jump_right_leg",
    ),
]

# Directory that holds the pre-generated gallery outputs.
EXAMPLE_GALLERY_OUTPUT_DIR = "examples/pregenerated"
def get_placeholder_html():
    """Build the "ready" banner shown in the motion viewer before any result exists."""
    markup_lines = [
        "<div style='height: 700px; display: flex; justify-content: center; align-items: center;",
        "background: linear-gradient(135deg, #EFF6FF 0%, #DBEAFE 100%);",
        "border-radius: 12px; border: 4px dashed #3B82F6;'>",
        "<div style='text-align: center;'>",
        "<p style='font-family: Bangers, cursive; font-size: 2.5rem; color: #3B82F6;",
        "text-shadow: 3px 3px 0px #FACC15; margin-bottom: 10px;'>",
        "🎬 READY TO CREATE!",
        "</p>",
        "<p style='font-family: Comic Neue, cursive; font-size: 1.2rem; color: #1F2937; font-weight: 700;'>",
        "Enter your motion description and click Generate!",
        "</p>",
        "</div>",
        "</div>",
    ]
    # The empty first/last items reproduce the leading and trailing newline
    # of the markup.
    return "\n".join(["", *markup_lines, ""])
# ============================================
# Fireworks AI LLM Client
# ============================================
class FireworksLLMClient:
    """Client for the Fireworks AI chat-completions API.

    Rewrites a free-form motion description into "A person ..." style text
    and estimates how long the motion should last.
    """

    # Bounds (seconds) applied to the estimated duration, and the value used
    # when the reply carries no usable duration.
    MIN_DURATION = 0.5
    MAX_DURATION = 12.0
    DEFAULT_DURATION = 5.0

    _SYSTEM_PROMPT = """You are a motion description expert. Your task is to:
1. Rewrite the user's motion description into clear, detailed English text suitable for motion generation.
2. Estimate the duration of the motion in seconds (between 0.5 and 12 seconds).
Rules for rewriting:
- Start with "A person..."
- Be specific about body parts, directions, and timing
- Keep it under 50 words
- Use present tense
Respond in JSON format only:
{"duration": <float>, "rewritten_text": "<string>"}
Example:
Input: "dance happily"
Output: {"duration": 5.0, "rewritten_text": "A person performs a joyful dance, swaying their hips side to side while raising both arms above their head and stepping lightly in place."}
"""

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """Store credentials; falsy arguments fall back to the module-level defaults."""
        self.api_key = api_key or FIREWORKS_API_KEY
        self.model = model or FIREWORKS_MODEL
        self.url = FIREWORKS_API_URL

    def rewrite_prompt_and_infer_time(self, text: str, max_timeout: int = 60) -> Tuple[float, str]:
        """
        Use Fireworks AI to rewrite the prompt and estimate motion duration.

        Args:
            text: Original motion description text
            max_timeout: Maximum timeout (seconds) for the API request

        Returns:
            Tuple of (predicted_duration, rewritten_text). The duration is
            clamped to [MIN_DURATION, MAX_DURATION].

        Raises:
            Exception: Any transport, HTTP-status or parsing error is logged
                and re-raised unchanged so the caller can fall back.
        """
        payload = {
            "model": self.model,
            "max_tokens": 4096,
            "top_p": 1,
            "top_k": 40,
            "presence_penalty": 0,
            "frequency_penalty": 0,
            "temperature": 0.6,
            "messages": [
                {"role": "system", "content": self._SYSTEM_PROMPT},
                {"role": "user", "content": f"Rewrite this motion description and estimate duration: {text}"},
            ],
        }
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        try:
            response = requests.post(
                self.url,
                headers=headers,
                data=json.dumps(payload),
                timeout=max_timeout,
            )
            response.raise_for_status()
            content = response.json()["choices"][0]["message"]["content"]
            return self._parse_reply(content, text)
        except Exception as e:
            print(f"Fireworks API error: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    @classmethod
    def _parse_reply(cls, content: str, fallback_text: str) -> Tuple[float, str]:
        """Extract ``(duration, rewritten_text)`` from the model's reply text.

        Args:
            content: Raw assistant message content.
            fallback_text: Text returned when the reply has no usable
                ``rewritten_text``.

        Raises:
            ValueError: If no JSON object can be found in ``content``
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        if not isinstance(content, str):
            raise ValueError("LLM reply has no text content")

        # Strip an optional markdown code fence around the JSON.
        cleaned = content.strip()
        if cleaned.startswith("```json"):
            cleaned = cleaned[7:]
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        cleaned = cleaned.strip()

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            # The model may surround the object with prose; retry on the
            # outermost {...} span before giving up.
            match = re.search(r"\{.*\}", cleaned, re.DOTALL)
            if match is None:
                raise
            parsed = json.loads(match.group(0))
        if not isinstance(parsed, dict):
            raise ValueError(f"Expected a JSON object from the LLM, got {type(parsed).__name__}")

        # A missing or malformed duration should not discard a good rewrite.
        try:
            duration = float(parsed.get("duration", cls.DEFAULT_DURATION))
        except (TypeError, ValueError):
            duration = cls.DEFAULT_DURATION
        if duration != duration:  # NaN compares unequal to itself
            duration = cls.DEFAULT_DURATION
        duration = max(cls.MIN_DURATION, min(cls.MAX_DURATION, duration))

        rewritten_text = parsed.get("rewritten_text")
        if not isinstance(rewritten_text, str) or not rewritten_text.strip():
            rewritten_text = fallback_text
        return duration, rewritten_text
def ensure_examples_generated(model_inference_obj) -> List[str]:
    """
    Render every gallery example that has no metadata file yet.

    An example counts as present when ``<filename>_meta.json`` exists in the
    gallery output directory. Missing ones are generated on CPU through
    ``model_inference_obj.run_inference``; a failure is logged and that
    example is left out of the result.

    Returns:
        Filenames of the examples that already existed or were generated.
    """
    os.makedirs(EXAMPLE_GALLERY_OUTPUT_DIR, exist_ok=True)
    available: List[str] = []

    for entry in EXAMPLE_GALLERY_LIST:
        stem = entry["filename"]
        meta_file = os.path.join(EXAMPLE_GALLERY_OUTPUT_DIR, f"{stem}_meta.json")

        if os.path.exists(meta_file):
            print(f">>> Example already exists: {meta_file}")
            available.append(stem)
            continue

        print(f">>> Generating example motion: {entry['prompt']}")
        try:
            # The returned values are not needed here; the call is made for
            # the files it writes. Unpacking is kept inside the try so an
            # unexpected return shape is reported like any other failure.
            _html, _fbx = model_inference_obj.run_inference(
                text=entry["prompt"],
                seeds_csv=entry["seeds"],
                motion_duration=entry["duration"],
                cfg_scale=entry["cfg_scale"],
                output_format="dict",
                original_text=entry["prompt"],
                output_dir=EXAMPLE_GALLERY_OUTPUT_DIR,
                output_filename=stem,
                device="cpu",
            )
            print(f">>> Example '{stem}' generated successfully!")
            available.append(stem)
        except Exception as err:
            print(f">>> Failed to generate example '{stem}': {err}")

    return available
def _wrap_in_srcdoc_iframe(document_html: str, height: str) -> str:
    """Return an ``<iframe>`` whose ``srcdoc`` attribute carries *document_html*."""
    from html import escape

    # The browser entity-decodes the attribute value before parsing it as a
    # document, so "&" has to be escaped together with the quote characters;
    # otherwise entities already present in the document are decoded one
    # level too early.
    escaped_html = escape(document_html, quote=True)
    return f"""
<iframe
srcdoc="{escaped_html}"
width="100%"
height="{height}"
style="border: none; border-radius: 12px; box-shadow: 0 4px 20px rgba(0,0,0,0.1);"
></iframe>
"""


def load_example_gallery_html(example_index: int = 0) -> str:
    """
    Load a specific pre-generated example and return iframe HTML for display.

    Args:
        example_index: Position of the example in ``EXAMPLE_GALLERY_LIST``.

    Returns:
        Iframe markup for the example; a warning box when its metadata file
        is missing; an empty string when the index is out of range or the
        viewer HTML could not be built.
    """
    if not 0 <= example_index < len(EXAMPLE_GALLERY_LIST):
        return ""

    example_dir = EXAMPLE_GALLERY_OUTPUT_DIR
    example_filename = EXAMPLE_GALLERY_LIST[example_index]["filename"]
    meta_path = os.path.join(example_dir, f"{example_filename}_meta.json")
    if not os.path.exists(meta_path):
        return f"""
<div style='height: 300px; display: flex; justify-content: center; align-items: center;
background: #FEF3C7; border: 3px dashed #F59E0B; border-radius: 12px; color: #D97706;'>
<p style='font-family: Comic Neue, cursive; font-weight: 700;'>
⚠️ Example not generated yet. Please restart the app.
</p>
</div>
"""

    try:
        html_content = generate_static_html_content(
            folder_name=example_dir,
            file_name=example_filename,
            hide_captions=False,
        )
    except Exception as e:
        # Best effort: a broken example must not take the whole gallery down.
        print(f">>> Failed to load example gallery: {e}")
        return ""
    return _wrap_in_srcdoc_iframe(html_content, "350px")
def get_example_gallery_grid_html() -> str:
    """
    Generate a grid layout HTML for all examples in the gallery.

    Each cell shows the (possibly truncated) prompt above the markup returned
    by ``load_example_gallery_html`` for that example. The grid uses at most
    two columns.
    """
    from html import escape

    if not EXAMPLE_GALLERY_LIST:
        return "<p>No examples configured.</p>"

    columns = min(len(EXAMPLE_GALLERY_LIST), 2)
    grid_items = []
    for idx, example in enumerate(EXAMPLE_GALLERY_LIST):
        iframe_html = load_example_gallery_html(idx)
        prompt = example["prompt"]
        prompt_short = f"{prompt[:60]}..." if len(prompt) > 60 else prompt
        # Prompt text is plain text, so escape it before placing it in markup.
        caption = escape(prompt_short)
        grid_items.append(f"""
<div class="example-grid-item" style="background: #FFFFFF; border-radius: 12px;
padding: 12px; border: 3px solid #1F2937; box-shadow: 4px 4px 0px #1F2937;">
<div style="font-family: 'Comic Neue', cursive; font-size: 14px; font-weight: 700; color: #1F2937;
margin-bottom: 8px; overflow: hidden; text-overflow: ellipsis; white-space: nowrap;">
📝 {caption}
</div>
{iframe_html}
</div>
""")

    return f"""
<div style="display: grid; grid-template-columns: repeat({columns}, 1fr); gap: 16px; padding: 8px;">
{"".join(grid_items)}
</div>
"""
def load_examples_from_txt(txt_path: str, example_record_fps=20, max_duration=12) -> List[Tuple[str, float]]:
    """Load example prompts from a ``.txt`` or ``.json`` file.

    Line format is ``<text>#<frame count>``; the frame count is divided by
    ``example_record_fps`` and capped at ``max_duration``. Lines without a
    ``#`` part, or with a frame count that cannot be parsed, get a duration
    of 5.0 seconds. Blank lines and lines starting with ``#`` are skipped.

    A ``.json`` file is expected to map keys to lists of such lines; keys
    containing ``_raw_chn`` or ``GENERATE_PROMPT_FORMAT`` are ignored.

    Args:
        txt_path: Path of the examples file.
        example_record_fps: Frames per second used to convert frame counts.
        max_duration: Upper bound (seconds) for a loaded duration.

    Returns:
        List of ``(text, duration)`` tuples; empty when the file is missing,
        unreadable, or has another extension.
    """
    default_duration = 5.0

    def _parse_line(line: str) -> Optional[Tuple[str, float]]:
        line = line.strip()
        if not line or line.startswith("#"):
            return None
        parts = line.split("#")
        if len(parts) < 2:
            return line, default_duration
        text = parts[0].strip()
        try:
            duration = int(parts[1]) / example_record_fps
        except (ValueError, ZeroDivisionError):
            # One bad frame count should not abort loading the whole file.
            print(f">>> Invalid frame count in example line, using default duration: {line}")
            return text, default_duration
        return text, min(duration, max_duration)

    examples: List[Tuple[str, float]] = []
    if not os.path.exists(txt_path):
        print(f">>> Examples file not found: {txt_path}")
        return examples

    try:
        if txt_path.endswith(".txt"):
            with open(txt_path, "r", encoding="utf-8") as f:
                raw_lines = f.readlines()
        elif txt_path.endswith(".json"):
            with open(txt_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, dict):
                raise ValueError("top-level JSON value must be an object")
            raw_lines = []
            for key, value in data.items():
                if "_raw_chn" in key or "GENERATE_PROMPT_FORMAT" in key:
                    continue
                if isinstance(value, (list, tuple)):
                    raw_lines.extend(item for item in value if isinstance(item, str))
        else:
            raw_lines = []
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f">>> Failed to load examples from {txt_path}: {e}")
        return examples

    for raw in raw_lines:
        parsed = _parse_line(raw)
        if parsed is not None:
            examples.append(parsed)
    print(f">>> Loaded {len(examples)} examples from {txt_path}")
    return examples
@spaces.GPU(duration=120)
def generate_motion_func(
    original_text: str,
    rewritten_text: str,
    seed_input: str,
    motion_duration: float,
    cfg_scale: float,
) -> Tuple[str, List[str]]:
    """Generate motion with GPU support.

    Uses the rewritten text when prompt engineering is enabled and that text
    is non-empty; otherwise the original text.

    Args:
        original_text: Text as typed by the user.
        rewritten_text: LLM-rewritten (possibly user-edited) text.
        seed_input: Seeds, passed through as ``seeds_csv``.
        motion_duration: Requested motion length.
        cfg_scale: Guidance strength.

    Returns:
        ``(html, fbx_files)``. On success ``html`` is an ``<iframe>`` holding
        the viewer document; on failure it is an error message/box and
        ``fbx_files`` is empty. Inference exceptions are reported this way
        rather than raised.
    """
    from html import escape

    use_prompt_engineering = USE_PROMPT_ENGINEERING
    output_dir = "output/gradio"

    # A textbox value may arrive as None; treat that like empty text.
    original_text = original_text or ""
    rewritten_clean = (rewritten_text or "").strip()
    original_clean = original_text.strip()

    # Determine which text to use
    if use_prompt_engineering and rewritten_clean:
        text_to_use = rewritten_clean
    elif original_clean:
        text_to_use = original_clean
    else:
        return "❌ Error: Input text is empty, please enter text first!", []

    try:
        fbx_ok = model_inference.fbx_available
        req_format = "fbx" if fbx_ok else "dict"
        html_content, fbx_files = model_inference.run_inference(
            text=text_to_use,
            seeds_csv=seed_input,
            motion_duration=motion_duration,
            cfg_scale=cfg_scale,
            output_format=req_format,
            original_text=original_text,
            output_dir=output_dir,
        )
        print("Running inference...completed")
        # The srcdoc value is entity-decoded by the browser before parsing,
        # so "&" must be escaped as well as the quotes.
        escaped_html = escape(html_content, quote=True)
        iframe_html = f"""
<iframe
srcdoc="{escaped_html}"
width="100%"
height="750px"
style="border: none; border-radius: 12px; box-shadow: 0 4px 20px rgba(0,0,0,0.1);"
></iframe>
"""
        return iframe_html, fbx_files
    except Exception as e:
        print(f"\t>>> Motion generation failed: {e}")
        # Exception text is arbitrary; escape it before embedding in markup.
        error_text = escape(str(e))
        return (
            f"""<div style='height: 700px; display: flex; justify-content: center; align-items: center;
background: #FEE2E2; border: 4px dashed #EF4444; border-radius: 12px;'>
<div style='text-align: center;'>
<p style='font-family: Bangers, cursive; font-size: 2rem; color: #EF4444;'>
❌ GENERATION FAILED!
</p>
<p style='font-family: Comic Neue, cursive; color: #1F2937; font-weight: 700;'>
{error_text}
</p>
</div>
</div>""",
            [],
        )
class T2MGradioUI:
    """Gradio front end for text-to-motion generation.

    Builds the page (prompt input, optional LLM rewrite, duration / seed /
    CFG controls, motion viewer, FBX download area, example prompts) and
    wires the event handlers. Reads the module-level ``model_inference``
    object while building the UI and may reset the module-level
    ``USE_PROMPT_ENGINEERING`` flag during construction.
    """

    # Status lines shared by several handlers.
    _PROMPT_STATUS = "📝 Enter text or select an example"
    _EXAMPLE_LOADED_STATUS = "✅ Example loaded! Click [🚀 GENERATE!] to create motion."

    def __init__(self, args):
        """Set up the optional LLM client and load the example prompts.

        Args:
            args: Object exposing ``output_dir`` and ``use_prompt_engineering``.
        """
        global USE_PROMPT_ENGINEERING

        self.output_dir = args.output_dir
        print(f"[{self.__class__.__name__}] output_dir: {self.output_dir}")
        self.llm_client = None
        self.prompt_engineering_available = args.use_prompt_engineering
        if self.prompt_engineering_available:
            try:
                self.llm_client = FireworksLLMClient()
                # Probe request: if it fails, rewriting is switched off for
                # this instance and for the module-level flag.
                self.llm_client.rewrite_prompt_and_infer_time("A person walks forward.", max_timeout=30)
                print(f"[{self.__class__.__name__}] Fireworks LLM client initialized successfully.")
            except Exception as e:
                print(f"[{self.__class__.__name__}] Fireworks LLM client initialization failed: {e}")
                self.prompt_engineering_available = False
                USE_PROMPT_ENGINEERING = False
                print(f"[{self.__class__.__name__}] USE_PROMPT_ENGINEERING set to False due to initialization failure")
        self.all_example_data = {}
        self._init_example_data()

    def _init_example_data(self):
        """Fill ``all_example_data`` from ``DATA_SOURCES``, with built-in fallbacks."""
        for source_name, file_path in DATA_SOURCES.items():
            examples = load_examples_from_txt(file_path)
            if examples:
                self.all_example_data[source_name] = examples
            else:
                self.all_example_data[source_name] = [
                    ("Twist at the waist and punch across the body.", 3.0),
                    ("A person is running then takes big leap.", 3.0),
                    ("A person holds a railing and walks down a set of stairs.", 5.0),
                    ("A man performs a fluid and rhythmic hip-hop style dance.", 5.0),
                ]
        print(f">>> Loaded data sources: {list(self.all_example_data.keys())}")

    def _generate_random_seeds(self):
        """Return four random seeds in [0, 999] as a comma-separated string."""
        seeds = [random.randint(0, 999) for _ in range(4)]
        return ",".join(map(str, seeds))

    @staticmethod
    def _example_label(text: str, limit: int) -> str:
        """Return *text* cut to *limit* characters plus "..." when it is longer."""
        return f"{text[:limit]}..." if len(text) > limit else text

    @staticmethod
    def _summarize_generation(display_html, fbx_list) -> Tuple[str, bool]:
        """Derive the status line and FBX-section visibility from a generation result.

        A result counts as successful when the display markup is an
        ``<iframe>``; anything else is treated as an error message.

        Returns:
            ``(status_message, show_fbx_section)``.
        """
        succeeded = isinstance(display_html, str) and display_html.lstrip().startswith("<iframe")
        if not succeeded:
            return "❌ Generation failed. See the message in the display panel for details.", False
        if fbx_list:
            return "🎉 DONE! Motion generated successfully! FBX files ready for download below.", True
        return "🎉 DONE! Motion generated successfully! View the result on the right.", False

    def _rewritten_box_update(self, text: str):
        """Update for the rewritten-text box: filled and shown only if rewriting is available."""
        enabled = self.prompt_engineering_available
        return gr.update(value=text if enabled else "", visible=enabled)

    def _blank_selection(self):
        """Handler result that clears the prompt and hides the rewritten-text box."""
        return (
            "",
            self._generate_random_seeds(),
            gr.update(),
            gr.update(value="", visible=False),
            gr.update(interactive=True),
            self._PROMPT_STATUS,
        )

    def _prompt_engineering(self, text: str, duration: float):
        """Rewrite *text* with the LLM.

        ``duration`` is accepted because the handler is bound with the
        duration slider as an input; it is not used.

        Returns:
            Updates for (rewritten text, generate button, duration slider,
            status). On LLM failure the original text is passed through.
        """
        if not text.strip():
            # Leave the generate button untouched: the UI also allows
            # generating directly, so an empty rewrite must not disable it.
            return "", gr.update(), gr.update(), "⚠️ Please enter text first!"
        print("\t>>> Using Fireworks LLM to rewrite text...")
        try:
            predicted_duration, rewritten_text = self.llm_client.rewrite_prompt_and_infer_time(text=text)
        except Exception as e:
            print(f"\t>>> Text rewriting failed: {e}")
            return (
                text,
                gr.update(interactive=True),
                gr.update(),
                f"⚠️ LLM rewriting failed: {str(e)}\n💡 Using your original input. Click [🚀 GENERATE!] to continue.",
            )
        return (
            rewritten_text,
            gr.update(interactive=True),
            gr.update(value=predicted_duration),
            "✅ Text rewritten! Review and edit if needed, then click [🚀 GENERATE!]",
        )

    def _get_example_choices(self):
        """Return dropdown-style choices: "Custom Input" plus 50-char example labels."""
        choices = ["Custom Input"]
        for example_data in self.all_example_data.values():
            choices.extend(self._example_label(text, 50) for text, _ in example_data)
        return choices

    def _on_example_select(self, selected_example):
        """Resolve a 50-char example label back to its text and duration."""
        if selected_example == "Custom Input":
            return self._blank_selection()
        for example_data in self.all_example_data.values():
            for text, duration in example_data:
                if self._example_label(text, 50) == selected_example:
                    return (
                        text,
                        self._generate_random_seeds(),
                        gr.update(value=duration),
                        self._rewritten_box_update(text),
                        gr.update(interactive=True),
                        self._EXAMPLE_LOADED_STATUS,
                    )
        return self._blank_selection()

    def _on_use_example(self, example_idx: int):
        """Load prompt, seeds and duration of gallery example *example_idx*."""
        if example_idx < 0 or example_idx >= len(EXAMPLE_GALLERY_LIST):
            return (
                "",
                "0,1,2,3",
                gr.update(),
                gr.update(value="", visible=False),
                gr.update(interactive=True),
                "⚠️ Please select a valid example",
            )
        example = EXAMPLE_GALLERY_LIST[example_idx]
        return (
            example["prompt"],
            example["seeds"],
            gr.update(value=example["duration"]),
            self._rewritten_box_update(example["prompt"]),
            gr.update(interactive=True),
            self._EXAMPLE_LOADED_STATUS,
        )

    def build_ui(self):
        """Create all components, bind the events and return the ``gr.Blocks`` app."""
        with gr.Blocks(css=APP_CSS) as demo:
            gr.LoginButton(value="Option: HuggingFace 'Login' for extra GPU quota +", size="sm")
            self.use_prompt_engineering_state = gr.State(self.prompt_engineering_available)
            self.output_dir_state = gr.State(self.output_dir)
            # HOME Badge
            gr.HTML("""
<div style="text-align: center; margin: 20px 0 10px 0;">
<a href="https://www.humangen.ai" target="_blank" style="text-decoration: none;">
<img src="https://img.shields.io/static/v1?label=🏠 HOME&message=Humangen.ai&color=3B82F6&labelColor=FACC15&style=for-the-badge" alt="HOME">
</a>
</div>
""")
            # Header
            gr.Markdown(HEADER_MD, elem_classes=["main-header"])
            gr.HTML(SUBTITLE_MD)
            with gr.Row():
                # Left control panel
                with gr.Column(scale=2, elem_classes=["left-panel"]):
                    if self.prompt_engineering_available:
                        input_placeholder = "Enter motion description in any language. Non-humanoid characters, multi-person, and camera motion are not supported. Click [📚 Example Prompts] for ideas!"
                    else:
                        input_placeholder = "Enter English text in 'A person...' format. Less than 50 words recommended. Click [📚 Example Prompts] for ideas!"
                    self.text_input = gr.Textbox(
                        label="📝 Motion Description",
                        placeholder=input_placeholder,
                        lines=3,
                        max_lines=10,
                        autoscroll=False,
                    )
                    self.rewritten_text = gr.Textbox(
                        label="✏️ Rewritten Text (Editable)",
                        placeholder="LLM-rewritten text will appear here. Feel free to edit!",
                        interactive=True,
                        visible=False,
                    )
                    self.duration_slider = gr.Slider(
                        minimum=0.5,
                        maximum=12,
                        value=5.0,
                        step=0.1,
                        label="⏱️ Motion Duration (seconds)",
                        info="Adjust the length of the generated motion",
                    )
                    with gr.Row():
                        if self.prompt_engineering_available:
                            self.rewrite_btn = gr.Button(
                                "🔄 REWRITE TEXT",
                                variant="secondary",
                                size="lg",
                                elem_classes=["rewrite-button"],
                            )
                        else:
                            # Kept (hidden) so the attribute always exists.
                            self.rewrite_btn = gr.Button(
                                "🔄 REWRITE (N/A)",
                                variant="secondary",
                                size="lg",
                                elem_classes=["rewrite-button"],
                                interactive=False,
                                visible=False,
                            )
                        self.generate_btn = gr.Button(
                            "🚀 GENERATE!",
                            variant="primary",
                            size="lg",
                            elem_classes=["generate-button"],
                            interactive=True,
                        )
                    with gr.Accordion("🔧 Advanced Settings", open=False):
                        self._build_advanced_settings()
                    if self.prompt_engineering_available:
                        status_msg = "📝 Enter text and click [🔄 REWRITE] or directly [🚀 GENERATE!]"
                    else:
                        status_msg = "📝 Enter text and click [🚀 GENERATE!] to create motion"
                    self.status_output = gr.Textbox(
                        label="📊 Status",
                        value=status_msg,
                        lines=2,
                        max_lines=5,
                        elem_classes=["status-textbox"],
                    )
                    # FBX Download Section - Clear separation
                    with gr.Column(visible=False, elem_classes=["fbx-download-section"]) as self.fbx_download_section:
                        gr.HTML("""
<div class="fbx-download-title">
📦 FBX FILE DOWNLOAD 📦
</div>
<p style="text-align: center; font-family: 'Comic Neue', cursive; color: #1F2937; margin-bottom: 10px;">
🎉 Your 3D motion files are ready! Click to download.
</p>
""")
                        if model_inference.fbx_available:
                            self.fbx_files = gr.File(
                                label="📁 Download FBX Files",
                                file_count="multiple",
                                interactive=False,
                            )
                        else:
                            # No FBX export: a State still receives the
                            # (empty) file list from the generation handler.
                            self.fbx_files = gr.State([])
                # Right display area
                with gr.Column(scale=3):
                    self.output_display = gr.HTML(
                        value=get_placeholder_html(),
                        show_label=False,
                        elem_classes=["flask-display"],
                    )
            # Example Selection with Radio Buttons
            with gr.Accordion("📚 Example Prompts", open=True):
                # Combine all examples into one list for radio
                all_examples = [
                    (text, duration)
                    for example_data in self.all_example_data.values()
                    for text, duration in example_data
                ]
                # Add gallery examples that are not already listed
                for example in EXAMPLE_GALLERY_LIST:
                    candidate = (example["prompt"], example["duration"])
                    if candidate not in all_examples:
                        all_examples.append(candidate)
                example_labels = [self._example_label(text, 70) for text, _ in all_examples]
                self.example_radio = gr.Radio(
                    choices=example_labels,
                    value=None,
                    label="Select an example prompt:",
                    interactive=True,
                )
                # Store example data for lookup
                self.example_data_list = all_examples
            # Footer
            gr.HTML(FOOTER_MD)
            self._bind_events()
        return demo

    def _build_advanced_settings(self):
        """Create the seed textbox, the random-seed button and the CFG slider."""
        with gr.Row():
            self.seed_input = gr.Textbox(
                label="🎯 Random Seeds",
                value="0,1,2,3",
                placeholder="e.g.: 0,1,2,3",
                scale=3,
            )
            self.dice_btn = gr.Button(
                "🎲",
                variant="secondary",
                size="sm",
                scale=1,
                min_width=50,
                elem_classes=["dice-btn"],
            )
        self.cfg_slider = gr.Slider(
            minimum=1,
            maximum=10,
            value=5.0,
            step=0.1,
            label="⚙️ CFG Strength",
        )

    def _bind_events(self):
        """Attach the click/change handlers to the components built by ``build_ui``."""
        # Generate random seeds
        self.dice_btn.click(self._generate_random_seeds, outputs=[self.seed_input])

        # Example radio selection
        def on_radio_select(selected_label):
            if selected_label is None:
                return gr.update(), gr.update(), gr.update(), gr.update(visible=False), self._PROMPT_STATUS
            # Find the example by its 70-char label
            for text, duration in self.example_data_list:
                if self._example_label(text, 70) == selected_label:
                    return (
                        text,
                        gr.update(value=duration),
                        self._generate_random_seeds(),
                        self._rewritten_box_update(text),
                        self._EXAMPLE_LOADED_STATUS,
                    )
            return gr.update(), gr.update(), gr.update(), gr.update(), self._PROMPT_STATUS

        self.example_radio.change(
            fn=on_radio_select,
            inputs=[self.example_radio],
            outputs=[self.text_input, self.duration_slider, self.seed_input, self.rewritten_text, self.status_output],
        )

        # Rewrite text
        if self.prompt_engineering_available:
            self.rewrite_btn.click(
                fn=lambda: "🔄 Rewriting with AI, please wait...",
                outputs=[self.status_output],
            ).then(
                self._prompt_engineering,
                inputs=[self.text_input, self.duration_slider],
                outputs=[self.rewritten_text, self.generate_btn, self.duration_slider, self.status_output],
            ).then(
                fn=lambda: gr.update(visible=True),
                outputs=[self.rewritten_text],
            )

        # Generate motion. The status is derived from the actual result so a
        # failed run is not reported as "DONE".
        def run_generation(original_text, rewritten_text, seed_input, motion_duration, cfg_scale):
            display_html, fbx_list = generate_motion_func(
                original_text, rewritten_text, seed_input, motion_duration, cfg_scale
            )
            status, show_fbx = self._summarize_generation(display_html, fbx_list)
            return display_html, fbx_list, status, gr.update(visible=show_fbx)

        self.generate_btn.click(
            fn=lambda: "🚀 Generating motion, please wait... (First run may take longer)",
            outputs=[self.status_output],
        ).then(
            run_generation,
            inputs=[self.text_input, self.rewritten_text, self.seed_input, self.duration_slider, self.cfg_slider],
            outputs=[self.output_display, self.fbx_files, self.status_output, self.fbx_download_section],
        )
def create_demo(final_model_path):
    """Build the Gradio app for the model stored under *final_model_path*.

    Raises:
        FileNotFoundError: If ``config.yml`` is missing from the model directory.
    """
    # Plain attribute bag consumed by T2MGradioUI.
    args = argparse.Namespace(
        model_path=final_model_path,
        output_dir="output/gradio",
        use_prompt_engineering=USE_PROMPT_ENGINEERING,
        use_text_encoder=True,
    )

    cfg = osp.join(args.model_path, "config.yml")
    # Checkpoint path is computed but not validated here.
    ckpt = osp.join(args.model_path, "latest.ckpt")
    if not osp.exists(cfg):
        raise FileNotFoundError(f">>> Configuration file not found: {cfg}")

    os.makedirs(args.output_dir, exist_ok=True)
    return T2MGradioUI(args=args).build_ui()
if __name__ == "__main__":
    # --- Command-line options ---
    cli_parser = argparse.ArgumentParser(description="HY-Motion Gradio App with Fireworks AI")
    cli_parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to listen on",
    )
    cli_parser.add_argument(
        "--no-prompt-engineering",
        action="store_true",
        help="Disable prompt engineering",
    )
    args = cli_parser.parse_args()

    # Module-level flag read by generate_motion_func and create_demo.
    USE_PROMPT_ENGINEERING = not args.no_prompt_engineering

    # --- Model assets ---
    try_to_download_text_encoder()
    final_model_path = try_to_download_model()

    # Module-level inference object used by the UI and the generation handler.
    model_inference = ModelInference(
        final_model_path,
        use_prompt_engineering=False,
        use_text_encoder=True,
    )
    model_inference.initialize_model(device="cpu")

    # Generate examples on first startup
    ensure_examples_generated(model_inference)

    # --- Serve ---
    demo = create_demo(final_model_path)
    demo.launch(server_name="0.0.0.0", server_port=args.port)