megalado committed on
Commit
b7eb387
·
1 Parent(s): d56c9e8

Improve MDM integration for better animation quality

Browse files
Files changed (1) hide show
  1. app.py +130 -380
app.py CHANGED
@@ -1,406 +1,156 @@
1
- import gradio as gr
2
- import torch
3
- import os
4
- import sys
5
- import numpy as np
6
- from pathlib import Path
7
- import traceback
8
- import subprocess
 
 
 
 
 
 
 
 
 
 
9
 
10
- def ensure_mdm_repo():
11
- """Ensure the MDM repository is cloned and set up"""
12
- if not Path("motion-diffusion-model").exists():
13
- print("Cloning Motion Diffusion Model repository...")
14
- subprocess.run(["git", "clone", "https://github.com/GuyTevet/motion-diffusion-model.git"])
15
-
16
- # Set up the repository
17
- print("Setting up the repository...")
18
- subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
19
-
20
- # Add necessary files
21
- os.chdir("motion-diffusion-model")
22
- subprocess.run(["bash", "prepare/download_smpl_files.sh"])
23
- subprocess.run(["bash", "prepare/download_glove.sh"])
24
- subprocess.run(["bash", "prepare/download_t2m_evaluators.sh"])
25
- os.chdir("..")
26
-
27
- # Add the repository to the Python path
28
- if "./motion-diffusion-model" not in sys.path:
29
- sys.path.append("./motion-diffusion-model")
30
-
31
- def text_to_motion(text_prompt, motion_length=3.0, seed=0):
32
- """Generate motion from text prompt using MDM"""
33
- try:
34
- print(f"Generating motion for: '{text_prompt}', length: {motion_length}s, seed: {seed}")
35
-
36
- # Ensure the MDM repository is set up
37
- ensure_mdm_repo()
38
-
39
- # Create output directory
40
- os.makedirs("output", exist_ok=True)
41
-
42
- # Get absolute path to the checkpoint
43
- checkpoint_path = os.path.abspath("checkpoints/opt000750000.pt")
44
- print(f"Using checkpoint: {checkpoint_path}")
45
-
46
- # Change to the MDM repository directory
47
- original_dir = os.getcwd()
48
- os.chdir("motion-diffusion-model")
49
-
50
- # List the sample directory to see what scripts are available
51
- print("Available scripts in sample directory:")
52
- if os.path.exists("sample"):
53
- for file in os.listdir("sample"):
54
- print(f" - {file}")
55
-
56
- # Find the generate script
57
- generate_script = None
58
- for root, dirs, files in os.walk("."):
59
- for file in files:
60
- if file.endswith(".py") and "generate" in file:
61
- generate_script = os.path.join(root, file)
62
- print(f"Found generate script: {generate_script}")
63
- break
64
- if generate_script:
65
- break
66
-
67
- if not generate_script:
68
- print("Could not find generate script")
69
- os.chdir(original_dir)
70
- return None
71
-
72
- # Create a simple Python script that uses our model
73
- with open("run_mdm.py", "w") as f:
74
- f.write("""
75
  import os
76
  import sys
77
- import torch
78
- import numpy as np
79
  from pathlib import Path
 
80
 
81
- # Add current directory to path
82
- sys.path.insert(0, os.getcwd())
83
-
84
- # Import required modules
85
- from utils.model_util import create_model_and_diffusion, load_saved_model
86
- from utils import dist_util
87
 
88
- def generate_motion(model_path, text_prompt, motion_length, seed):
89
- # Set up model
90
- model, diffusion = create_model_and_diffusion(
91
- model_path=model_path,
92
- dataset='humanml',
93
- diffusion_steps=1000,
94
- num_frames=motion_length * 20, # Assuming 20 fps
95
- )
96
-
97
- # Load checkpoint
98
- load_saved_model(model, model_path)
99
- model.eval()
100
-
101
- # Set seed
102
- torch.manual_seed(seed)
103
-
104
- # Generate motion
105
- with torch.no_grad():
106
- # Process text
107
- text_emb = model.encode_text(text_prompt)
108
-
109
- # Generate motion
110
- samples = diffusion.p_sample_loop(
111
- model.forward_with_text,
112
- shape=(1, model.njoints, model.nfeats, int(motion_length * 20)),
113
- text_emb=text_emb,
114
- clip_denoised=True,
115
  )
116
-
117
- # Save to file
118
- os.makedirs('output', exist_ok=True)
119
- output_path = f'output/motion_{abs(hash(text_prompt) % 10000)}_{int(motion_length)}_{seed}.mp4'
120
-
121
- # Visualize and save
122
- from visualization.visualize import visualize
123
- visualize(samples.cpu().numpy(), output_path)
124
-
125
- return output_path
126
 
127
- if __name__ == '__main__':
128
- import argparse
129
-
130
- parser = argparse.ArgumentParser()
131
- parser.add_argument('--model_path', type=str, required=True)
132
- parser.add_argument('--text_prompt', type=str, required=True)
133
- parser.add_argument('--motion_length', type=float, default=3.0)
134
- parser.add_argument('--seed', type=int, default=0)
135
-
136
- args = parser.parse_args()
137
-
138
- output_path = generate_motion(
139
- args.model_path,
140
- args.text_prompt,
141
- args.motion_length,
142
- args.seed
143
- )
144
-
145
- print(f"Generated motion saved to: {output_path}")
146
- """)
147
-
148
- # Run our custom script
149
- cmd = [
150
- "python",
151
- "run_mdm.py",
152
- "--model_path", checkpoint_path,
153
- "--text_prompt", text_prompt,
154
- "--motion_length", str(motion_length),
155
- "--seed", str(int(seed))
156
- ]
157
-
158
- print(f"Running command: {' '.join(cmd)}")
159
- result = subprocess.run(cmd, capture_output=True, text=True)
160
-
161
- # Print the output for debugging
162
- print("Command output:", result.stdout)
163
- if result.stderr:
164
- print("Command error:", result.stderr)
165
-
166
- # Check for output files
167
- output_mp4 = None
168
- for root, dirs, files in os.walk("."):
169
- for file in files:
170
- if file.endswith(".mp4"):
171
- output_mp4 = os.path.join(root, file)
172
- print(f"Found output file: {output_mp4}")
173
- break
174
- if output_mp4:
175
- break
176
-
177
- # Return to the original directory
178
- os.chdir(original_dir)
179
-
180
- # If we found an output file, copy it to our output directory
181
- if output_mp4:
182
- output_path = f"output/output_{abs(hash(text_prompt) % 10000)}_{int(motion_length)}_{seed}.mp4"
183
- subprocess.run(["cp", os.path.join("motion-diffusion-model", output_mp4), output_path])
184
- print(f"Copied output to {output_path}")
185
- return output_path
186
-
187
- # Fall back to simplified motion generation
188
- print("MDM generation failed, falling back to simplified motion")
189
- return create_simplified_motion(text_prompt, motion_length, seed)
190
-
191
- except Exception as e:
192
- print(f"Error generating motion: {str(e)}")
193
- print(traceback.format_exc())
194
-
195
- # Fall back to simplified motion generation
196
- try:
197
- return create_simplified_motion(text_prompt, motion_length, seed)
198
- except:
199
- return None
200
-
201
- def create_simplified_motion(text_prompt, motion_length, seed):
202
- """Create a simplified motion animation as fallback"""
203
- print("Creating simplified motion animation...")
204
-
205
- # Create output directory
206
- os.makedirs("output", exist_ok=True)
207
- output_path = f"output/simplified_{abs(hash(text_prompt) % 10000)}_{int(motion_length)}_{seed}.mp4"
208
-
209
- # Create a standalone script to generate the motion
210
- with open("simplified_motion.py", "w") as f:
211
- f.write(f"""
212
- import numpy as np
213
- import matplotlib.pyplot as plt
214
- from matplotlib.animation import FuncAnimation
215
- import os
216
- from mpl_toolkits.mplot3d import Axes3D
217
-
218
- # Set random seed for reproducibility
219
- np.random.seed({seed})
220
-
221
- # Parse the text prompt to detect actions
222
- text_lower = "{text_prompt.lower()}"
223
- walking = "walk" in text_lower
224
- running = "run" in text_lower
225
- jumping = "jump" in text_lower
226
- dancing = "danc" in text_lower
227
- turning = "turn" in text_lower or "spin" in text_lower
228
- waving = "wave" in text_lower
229
 
230
- # Set parameters
231
- frames = int({motion_length} * 30) # 30 fps
232
- speed = 4.0 if running else 2.0 if walking else 1.0
 
 
233
 
234
- # Create motion data - 16 joints with 3D coordinates
235
- joints = 16
236
- dims = 3
237
- motion = np.zeros((frames, joints, dims))
238
 
239
- # Generate the motion
240
- for frame in range(frames):
241
- t = frame / frames
242
-
243
- # Basic forward motion or turning
244
- if turning:
245
- angle = t * 2 * np.pi * 2
246
- motion[frame, :, 0] = np.cos(angle) * 2
247
- motion[frame, :, 1] = np.sin(angle) * 2
248
- else:
249
- motion[frame, :, 0] = t * speed * 4
250
-
251
- # Root joint (pelvis) with jumping or bouncing
252
- if jumping:
253
- motion[frame, 0, 2] = 0.5 + 0.5 * np.sin(t * 2 * np.pi * 3)
254
- else:
255
- motion[frame, 0, 2] = 0.1 * np.sin(t * 2 * np.pi * speed * 2) + 1 if walking or running else 0.05 + 1
256
-
257
- # Spine and head (joints 1, 2, 3)
258
- for i in range(1, 4):
259
- motion[frame, i, 2] = motion[frame, 0, 2] + i * 0.2
260
-
261
- # Add dancing motion for upper body
262
- if dancing:
263
- motion[frame, i, 1] = 0.2 * np.sin(t * 2 * np.pi * 4 + np.pi * i/4)
264
-
265
- # Left leg (joints 4, 5, 6)
266
- leg_freq = speed * 2
267
- swing_leg_l = np.sin(t * 2 * np.pi * leg_freq)
268
- motion[frame, 4, 1] = 0.2
269
- motion[frame, 4, 2] = motion[frame, 0, 2] - 0.1
270
- motion[frame, 5, 1] = 0.2
271
- motion[frame, 5, 2] = motion[frame, 4, 2] - 0.5 + swing_leg_l * 0.3
272
- motion[frame, 6, 1] = 0.2
273
- motion[frame, 6, 2] = motion[frame, 5, 2] - 0.5 + swing_leg_l * 0.3
274
-
275
- # Right leg (joints 7, 8, 9)
276
- swing_leg_r = np.sin(t * 2 * np.pi * leg_freq + np.pi)
277
- motion[frame, 7, 1] = -0.2
278
- motion[frame, 7, 2] = motion[frame, 0, 2] - 0.1
279
- motion[frame, 8, 1] = -0.2
280
- motion[frame, 8, 2] = motion[frame, 7, 2] - 0.5 + swing_leg_r * 0.3
281
- motion[frame, 9, 1] = -0.2
282
- motion[frame, 9, 2] = motion[frame, 8, 2] - 0.5 + swing_leg_r * 0.3
283
-
284
- # Left arm (joints 10, 11, 12)
285
- if waving and t > 0.3 and t < 0.7:
286
- # Waving motion
287
- wave = 0.5 * np.sin(t * 2 * np.pi * 8)
288
- motion[frame, 10, 1] = 0.3
289
- motion[frame, 10, 2] = motion[frame, 3, 2] - 0.2
290
- motion[frame, 11, 1] = 0.5
291
- motion[frame, 11, 2] = motion[frame, 10, 2]
292
- motion[frame, 12, 1] = 0.7
293
- motion[frame, 12, 2] = motion[frame, 11, 2] + wave
294
- else:
295
- # Normal arm swing
296
- swing_arm_l = np.sin(t * 2 * np.pi * leg_freq + np.pi)
297
- motion[frame, 10, 1] = 0.3
298
- motion[frame, 10, 2] = motion[frame, 3, 2] - 0.2
299
- motion[frame, 11, 1] = 0.3 + swing_arm_l * 0.2
300
- motion[frame, 11, 2] = motion[frame, 10, 2] - 0.4
301
- motion[frame, 12, 1] = 0.3 + swing_arm_l * 0.4
302
- motion[frame, 12, 2] = motion[frame, 11, 2] - 0.4
303
-
304
- # Right arm (joints 13, 14, 15)
305
- swing_arm_r = np.sin(t * 2 * np.pi * leg_freq)
306
- motion[frame, 13, 1] = -0.3
307
- motion[frame, 13, 2] = motion[frame, 3, 2] - 0.2
308
- motion[frame, 14, 1] = -0.3 + swing_arm_r * 0.2
309
- motion[frame, 14, 2] = motion[frame, 13, 2] - 0.4
310
- motion[frame, 15, 1] = -0.3 + swing_arm_r * 0.4
311
- motion[frame, 15, 2] = motion[frame, 14, 2] - 0.4
312
 
313
- # Create figure for visualization
314
- fig = plt.figure(figsize=(10, 6))
315
- ax = fig.add_subplot(111, projection='3d')
316
 
317
- # Define connections between joints
318
- connections = [
319
- (0, 1), (1, 2), (2, 3), # Spine and head
320
- (0, 4), (4, 5), (5, 6), # Left leg
321
- (0, 7), (7, 8), (8, 9), # Right leg
322
- (3, 10), (10, 11), (11, 12), # Left arm
323
- (3, 13), (13, 14), (14, 15) # Right arm
324
- ]
325
 
326
- # Animation update function
327
- def update(frame):
328
- ax.clear()
329
-
330
- # Set axis limits
331
- max_range = max(4, np.max(np.abs(motion)))
332
- ax.set_xlim([-max_range/2, max_range/2 + motion[frame, 0, 0]])
333
- ax.set_ylim([-max_range/2, max_range/2])
334
- ax.set_zlim([0, max_range])
335
-
336
- # Set labels
337
- ax.set_xlabel('X (forward)')
338
- ax.set_ylabel('Y (sideways)')
339
- ax.set_zlabel('Z (upward)')
340
-
341
- # Plot joints
342
- ax.scatter(motion[frame, :, 0],
343
- motion[frame, :, 1],
344
- motion[frame, :, 2], c='b', marker='o')
345
-
346
- # Plot connections
347
- for start, end in connections:
348
- ax.plot([motion[frame, start, 0], motion[frame, end, 0]],
349
- [motion[frame, start, 1], motion[frame, end, 1]],
350
- [motion[frame, start, 2], motion[frame, end, 2]], 'r-')
351
-
352
- # Add action type to title
353
- action_type = ""
354
- if running:
355
- action_type = "Running"
356
- elif walking:
357
- action_type = "Walking"
358
- elif jumping:
359
- action_type = "Jumping"
360
- elif dancing:
361
- action_type = "Dancing"
362
- elif turning:
363
- action_type = "Turning"
364
- elif waving:
365
- action_type = "Waving"
366
- else:
367
- action_type = "Moving"
368
-
369
- ax.set_title(action_type + " Motion - Frame " + str(frame))
370
- return ax
371
 
372
- # Create animation
373
- anim = FuncAnimation(fig, update, frames=min(frames, 180), interval=1000/30)
 
 
 
 
374
 
375
- # Save animation
376
- os.makedirs(os.path.dirname("{output_path}") or '.', exist_ok=True)
377
- anim.save("{output_path}", writer='ffmpeg', fps=30)
378
- plt.close()
379
 
380
- print("Animation saved to {output_path}")
381
- """)
382
-
383
- # Run the script
384
- subprocess.run(["python", "simplified_motion.py"])
385
-
386
- if os.path.exists(output_path):
387
- return output_path
388
- else:
389
- return None
390
 
391
- # Create the Gradio interface
392
  demo = gr.Interface(
393
  fn=text_to_motion,
394
  inputs=[
395
- gr.Textbox(label="Text Prompt", placeholder="A person walks forward, then turns left", lines=3, value="A person walking"),
396
- gr.Slider(minimum=1.0, maximum=9.8, value=3.0, label="Motion Length (seconds)"),
397
- gr.Number(label="Random Seed", value=0)
 
 
 
 
 
 
 
 
 
 
398
  ],
399
  outputs=gr.Video(label="Generated Motion"),
400
- title="Motion Diffusion Model Demo",
401
- description="Generate human motions from text descriptions. Try prompts with actions like 'walk', 'run', 'jump', 'dance', 'turn', or 'wave'."
 
 
 
402
  )
403
 
404
- # Launch the app
 
 
 
405
  if __name__ == "__main__":
406
- demo.launch()
 
1
+ # app.py
2
+ """
3
+ Motion Diffusion Demo on Hugging Face Spaces
4
+ -------------------------------------------
5
+ Generates human motion from a text prompt using the Motion-Diffusion-Model (MDM)
6
+ checkpoint already uploaded to this Space.
7
+
8
+ Key points
9
+ ~~~~~~~~~~
10
+ * **Repo location** : motion-diffusion-model/
11
+ * **Checkpoint location** : checkpoints/opt000750000.pt (path kept intact)
12
+ * We call the official `sample.generate` CLI so we inherit every default the
13
+ authors bundled with the checkpoint (vocab, SMPL params, diffusion schedule …).
14
+ * If anything goes wrong the function falls back to returning `None`, allowing
15
+ Gradio to show an empty result instead of crashing the Space.
16
+ """
17
+
18
+ from __future__ import annotations
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  import os
21
  import sys
22
+ import subprocess
23
+ import traceback
24
  from pathlib import Path
25
+ from typing import Optional
26
 
27
+ import gradio as gr
 
 
 
 
 
28
 
29
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

REPO_DIR = "motion-diffusion-model"              # local clone of the MDM repo
CHECKPOINT_PATH = "checkpoints/opt000750000.pt"  # pre-uploaded weights; path kept as-is
OUTPUT_DIR = "output"                            # destination for final MP4 files
MAX_LEN_SEC = 9.8                                # generator's hard upper limit, in seconds

# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
41
+
42
def ensure_repo_ready() -> None:
    """Guarantee the MDM repository is present locally and importable.

    Clones GuyTevet/motion-diffusion-model on first use, then prepends the
    repo's absolute path to ``sys.path``. Safe to call repeatedly: the clone
    is skipped when the folder exists and the path is inserted only once.
    """
    repo = Path(REPO_DIR)
    if not repo.exists():
        print("[setup] Cloning Motion-Diffusion-Model repo …")
        clone_cmd = [
            "git",
            "clone",
            "https://github.com/GuyTevet/motion-diffusion-model.git",
            REPO_DIR,
        ]
        subprocess.run(clone_cmd, check=True)

    repo_abs = str(repo.resolve())
    if repo_abs not in sys.path:
        sys.path.insert(0, repo_abs)
59
+
60
+
61
def run_mdm(prompt: str, length: float, seed: int) -> Optional[str]:
    """Generate a motion MP4 via the authors' ``sample.generate`` CLI.

    Parameters
    ----------
    prompt : text description of the desired motion.
    length : requested clip length in seconds (clamped to ``MAX_LEN_SEC``).
    seed   : RNG seed forwarded to the generator.

    Returns
    -------
    Path of the resulting MP4 inside ``OUTPUT_DIR``, or ``None`` when the
    generator fails or produces no video.

    Raises
    ------
    FileNotFoundError
        When the checkpoint file is missing.
    """
    ensure_repo_ready()

    ckpt = Path(CHECKPOINT_PATH).resolve()
    if not ckpt.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

    # The script creates its own result folder; we just need somewhere to move
    # the freshest MP4 afterwards.
    Path(OUTPUT_DIR).mkdir(exist_ok=True)

    cmd = [
        "python",
        "-m",
        "sample.generate",
        "--model_path",
        str(ckpt),
        "--text_prompt",
        prompt,
        "--motion_length",
        f"{min(length, MAX_LEN_SEC):.2f}",
        "--seed",
        # int() guards against Gradio delivering the seed as a float
        # (e.g. 0.0), which would produce an invalid "--seed 0.0" argument.
        str(int(seed)),
    ]

    print("[run]", " ".join(cmd))
    try:
        subprocess.run(cmd, cwd=REPO_DIR, check=True)
    except subprocess.CalledProcessError as exc:
        print("[error] sample.generate failed:", exc)
        return None

    # Grab the newest MP4 produced by the script.
    mp4_files = list(Path(REPO_DIR).rglob("*.mp4"))
    if not mp4_files:
        print("[warn] No MP4 file produced by the generator.")
        return None

    newest = max(mp4_files, key=lambda p: p.stat().st_mtime)
    final_path = Path(OUTPUT_DIR) / newest.name
    # shutil.move works across filesystems; Path.replace raises OSError
    # (EXDEV) when OUTPUT_DIR sits on a different mount than the repo tree,
    # which happens with mounted volumes in container/Spaces setups.
    shutil.move(str(newest), str(final_path))

    print(f"[ok] Motion video saved to {final_path}")
    return str(final_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
 
 
 
107
 
108
def fallback_motion(prompt: str, length: float, seed: int) -> Optional[str]:
    """Last-resort stand-in for a failed generation.

    Deliberately produces no video: Gradio renders an empty output component
    when given ``None``, which keeps the Space alive instead of crashing.
    """
    print("[fallback] Returning empty result.")
    return None
 
 
 
 
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
def text_to_motion(prompt: str, length: float = 3.0, seed: int = 0):
    """Gradio entry point: run the real model, fall back on failure.

    Any exception from the generator is logged (full traceback) rather than
    propagated, so the UI never crashes the Space.
    """
    try:
        video = run_mdm(prompt, length, seed)
    except Exception:
        print(traceback.format_exc())
        return fallback_motion(prompt, length, seed)
    if video:
        return video
    return fallback_motion(prompt, length, seed)
120
 
 
 
 
 
121
 
122
+ # ---------------------------------------------------------------------------
123
+ # Gradio UI
124
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
125
 
 
126
# Input widgets, named for readability.
_prompt_box = gr.Textbox(
    label="Text Prompt",
    lines=3,
    value="A person walks forward and waves.",
)
_length_slider = gr.Slider(
    minimum=1.0,
    maximum=MAX_LEN_SEC,
    step=0.1,
    value=3.0,
    label="Motion Length (seconds)",
)
_seed_input = gr.Number(label="Random Seed", value=0, precision=0)

# Top-level Gradio app: prompt + length + seed -> generated motion video.
demo = gr.Interface(
    fn=text_to_motion,
    inputs=[_prompt_box, _length_slider, _seed_input],
    outputs=gr.Video(label="Generated Motion"),
    title="Motion Diffusion Model Demo (HumanML)",
    description=(
        "Enter an action description (e.g. 'A person runs in a circle and jumps').\n"
        "The model returns a skeletal MP4 generated with the HumanML checkpoint."
    ),
)
  )
150
 
151
# ---------------------------------------------------------------------------
# Launch
# ---------------------------------------------------------------------------

# Start the Gradio server only when executed as a script (not when imported).
if __name__ == "__main__":
    demo.launch()