testMDM / app.py
megalado
Improve MDM integration for better animation quality
96802a8
# app.py
"""
Motion Diffusion Demo on Hugging Face Spaces
-------------------------------------------
Generate human motion from a text prompt with **Motion‑Diffusion‑Model (MDM)**.
We keep the user‑supplied layout:
* repo folder → motion-diffusion-model/
* checkpoint file → checkpoints/opt000750000.pt
The script runs `python -m sample.generate` from inside the repo folder,
mirroring the original command-line workflow that was already known to work.
"""
from __future__ import annotations

import shutil
import subprocess
import sys
import traceback
import uuid
from pathlib import Path
from typing import Optional

import gradio as gr
# ---------------------------------------------------------------------------
# Config – edit only if your Space layout changes
# ---------------------------------------------------------------------------
# Folder name the MDM repository is cloned into (relative to the working dir).
REPO_DIR: str = "motion-diffusion-model"
# User-supplied model weights; resolved to an absolute path before use.
CHECKPOINT_PATH: str = "checkpoints/opt000750000.pt"
# Directory where finished MP4 files are collected.
OUTPUT_DIR: str = "output"
# Upper bound, in seconds, for the requested motion length.
# NOTE(review): presumably 196 frames at 20 fps for HumanML3D — confirm.
MAX_LEN_SEC: float = 9.8
# ---------------------------------------------------------------------------
# Helper: make sure repo exists + on sys.path
# ---------------------------------------------------------------------------
def ensure_repo_ready() -> None:
    """Clone the MDM repository when missing and expose it on ``sys.path``.

    An existing ``REPO_DIR`` folder is reused as-is; only a missing one
    triggers ``git clone``. Because ``check=True`` is used, a failed clone
    propagates as ``subprocess.CalledProcessError``.
    """
    repo = Path(REPO_DIR)
    if not repo.exists():
        print("[setup] Cloning Motion‑Diffusion‑Model repo …")
        clone_cmd = [
            "git",
            "clone",
            "https://github.com/GuyTevet/motion-diffusion-model.git",
            REPO_DIR,
        ]
        subprocess.run(clone_cmd, check=True)

    # Prepend the absolute repo path once so its modules are importable.
    repo_abs = str(repo.resolve())
    if repo_abs in sys.path:
        return
    sys.path.insert(0, repo_abs)
# ---------------------------------------------------------------------------
# Core Generation
# ---------------------------------------------------------------------------
def _move_newest_mp4(src_dir: Path, dest_dir: Path, prefix: str) -> Optional[str]:
    """Move the most recently modified MP4 under *src_dir* into *dest_dir*.

    The moved file is renamed to ``f"{prefix}_{original_name}"``, so videos
    from different runs never overwrite each other. Returns the new path as a
    string, or ``None`` if *src_dir* holds no MP4 (or does not exist).
    """
    mp4_files = list(src_dir.rglob("*.mp4")) if src_dir.is_dir() else []
    if not mp4_files:
        print("[warn] No MP4 produced.")
        return None
    newest = max(mp4_files, key=lambda p: p.stat().st_mtime)
    final_path = dest_dir / f"{prefix}_{newest.name}"
    newest.replace(final_path)
    print(f"[ok] Motion video saved to {final_path}")
    return str(final_path)


def run_mdm(prompt: str, length: float, seed: int) -> Optional[str]:
    """Call the official MDM generator inside the repo and return the MP4 path.

    Args:
        prompt: Text description of the motion. Surrounding whitespace is
            stripped; an empty prompt produces no video.
        length: Requested motion length in seconds, clamped to ``MAX_LEN_SEC``.
        seed: Random seed forwarded to ``sample.generate``.

    Returns:
        Path (as ``str``) of the generated MP4 inside ``OUTPUT_DIR``. Returns
        ``None`` when the prompt is empty, when the generator exits with an
        error, or when it produces no MP4.

    Raises:
        FileNotFoundError: If the checkpoint at ``CHECKPOINT_PATH`` is missing.
        ValueError: If ``length`` is not positive.
    """
    ensure_repo_ready()
    ckpt = Path(CHECKPOINT_PATH).resolve()
    if not ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

    prompt = (prompt or "").strip()
    if not prompt:
        print("[warn] Empty prompt; nothing to generate.")
        return None
    if length <= 0:
        raise ValueError(f"Motion length must be positive, got {length!r}")

    out_root = Path(OUTPUT_DIR).resolve()
    out_root.mkdir(parents=True, exist_ok=True)
    # Each run gets its own fresh output directory. If --output_dir is not
    # given, the generator writes next to the checkpoint (outside REPO_DIR).
    # Scanning the repo could then miss the new video or return a stale one
    # from an earlier run. The path must be absolute because the subprocess
    # runs with cwd=REPO_DIR.
    run_id = uuid.uuid4().hex[:12]
    run_dir = out_root / f"run_{run_id}"

    cmd = [
        # Same interpreter (and installed packages) as this app, rather than
        # whatever "python" happens to be first on PATH.
        sys.executable,
        "-m",
        "sample.generate",
        "--model_path",
        str(ckpt),
        "--output_dir",
        str(run_dir),
        "--text_prompt",
        prompt,
        "--motion_length",
        f"{min(length, MAX_LEN_SEC):.2f}",
        "--seed",
        str(seed),
    ]
    print("[run]", " ".join(cmd))
    try:
        subprocess.run(cmd, cwd=REPO_DIR, check=True)
        return _move_newest_mp4(run_dir, out_root, prefix=run_id)
    except subprocess.CalledProcessError as exc:
        print("[error] Generator failed:", exc)
        return None
    finally:
        # The chosen MP4 has already been moved out; remove the remaining
        # per-run artifacts so repeated requests do not fill the disk.
        shutil.rmtree(run_dir, ignore_errors=True)
# ---------------------------------------------------------------------------
# Gradio wrapper
# ---------------------------------------------------------------------------
def text_to_motion(prompt: str, length: float = 3.0, seed: int = 0):
    """Gradio callback that generates a motion video and never raises.

    Args:
        prompt: Text description from the textbox.
        length: Motion length in seconds from the slider.
        seed: Value of the seed field. The UI may deliver ``None`` (field
            cleared) or a float, so it is normalised to ``int`` here.

    Returns:
        Path of the generated MP4, or ``None`` on any failure. On an
        exception, the traceback is printed to stdout.
    """
    try:
        # str(None) / str(3.0) would otherwise be forwarded verbatim as
        # "--seed None" / "--seed 3.0". The generator presumably parses
        # --seed as an int and would reject those values.
        seed_int = 0 if seed is None else int(seed)
        return run_mdm(prompt, float(length), seed_int)
    except Exception:
        # UI boundary: log the failure and return an empty video instead of
        # propagating.
        print(traceback.format_exc())
        return None
# ---------------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------------
def _build_demo() -> gr.Interface:
    """Assemble the UI: prompt, length and seed in; one video out."""
    prompt_box = gr.Textbox(
        label="Text Prompt",
        lines=3,
        value="A person walks forward and waves.",
    )
    length_slider = gr.Slider(
        minimum=1.0,
        maximum=MAX_LEN_SEC,  # cap the slider at the configured max length
        step=0.1,
        value=3.0,
        label="Motion Length (seconds)",
    )
    seed_box = gr.Number(label="Random Seed", value=0, precision=0)
    blurb = (
        "Describe an action — e.g. 'A person runs in a circle and jumps'. "
        "The HumanML checkpoint returns a skeletal MP4."
    )
    return gr.Interface(
        fn=text_to_motion,
        inputs=[prompt_box, length_slider, seed_box],
        outputs=gr.Video(label="Generated Motion"),
        title="Motion Diffusion Model Demo (HumanML)",
        description=blurb,
    )


demo = _build_demo()
# ---------------------------------------------------------------------------
# Launch
# ---------------------------------------------------------------------------
def _main() -> None:
    """Serve the demo with Gradio's default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()