megalado committed on
Commit
e0521e6
·
1 Parent(s): c13b92a

Update app to use proper MDM checkpoint and setup

Browse files
Files changed (1) hide show
  1. app.py +60 -141
app.py CHANGED
@@ -1,29 +1,4 @@
1
  import gradio as gr
2
- import torch
3
- import os
4
- import sys
5
- import numpy as np
6
- from pathlib import Path
7
- import traceback
8
- import subprocess
9
- import glob
10
- import requests
11
- import time
12
- import random
13
-
14
- def create_motion_animation(text_prompt, motion_length=3.0, seed=0):
15
- """Create a motion animation based on input parameters"""
16
- print(f"Creating animation for: '{text_prompt}', length: {motion_length}s, seed: {seed}")
17
-
18
- # Create a unique filename based on parameters
19
- output_filename = f"output_{abs(hash(text_prompt) % 10000)}_{int(motion_length)}_{seed}.mp4"
20
-
21
- # Sanitize the text prompt for Python string
22
- safe_prompt = text_prompt.replace('"', '')
23
-
24
- # Create a simple visualization script
25
- with open("motion_animation.py", "w") as f:
26
- f.write("""
27
  import numpy as np
28
  import matplotlib.pyplot as plt
29
  from matplotlib.animation import FuncAnimation
@@ -31,71 +6,60 @@ import os
31
  from mpl_toolkits.mplot3d import Axes3D
32
  import random
33
 
34
- # Set random seeds for reproducibility
35
- np.random.seed({})
36
- random.seed({})
37
-
38
- # Animation parameters
39
- text_prompt = "{}"
40
- walking = "walk" in text_prompt.lower()
41
- running = "run" in text_prompt.lower()
42
- jumping = "jump" in text_prompt.lower()
43
- dancing = "danc" in text_prompt.lower()
44
- turning = "turn" in text_prompt.lower() or "spin" in text_prompt.lower()
45
- waving = "wave" in text_prompt.lower()
46
-
47
- # Create a motion based on text prompt
48
- def generate_motion(frames={}):
49
- joints = 16 # Number of joints in a simplified skeleton
50
- dims = 3 # x, y, z
51
-
 
 
 
 
52
  motion = np.zeros((frames, joints, dims))
53
 
54
- # Set speed based on motion type
55
- if running:
56
- speed = 4.0
57
- elif walking:
58
- speed = 2.0
59
- else:
60
- speed = 1.0
61
-
62
- # Create the motion
63
  for frame in range(frames):
64
  t = frame / frames
65
 
66
- # Basic forward motion
67
  if turning:
68
- # Move in a circle
69
  angle = t * 2 * np.pi * 2
70
  motion[frame, :, 0] = np.cos(angle) * 2
71
  motion[frame, :, 1] = np.sin(angle) * 2
72
  else:
73
- # Move forward
74
  motion[frame, :, 0] = t * speed * 4
75
 
76
- # Root joint (pelvis)
77
  if jumping:
78
- # Add jumping motion
79
- jump_height = 0.5 + 0.5 * np.sin(t * 2 * np.pi * 3)
80
- motion[frame, 0, 2] = jump_height
81
  else:
82
- # Regular walking bounce
83
- bounce = 0.1 * np.sin(t * 2 * np.pi * speed * 2) if walking or running else 0.05
84
- motion[frame, 0, 2] = bounce + 1
85
 
86
  # Spine and head (joints 1, 2, 3)
87
  for i in range(1, 4):
88
- motion[frame, i, 2] = motion[frame, 0, 2] + i * 0.2 # Stack joints vertically
89
-
90
- # Add dancing motion if needed
91
- if dancing:
92
- wiggle = 0.2 * np.sin(t * 2 * np.pi * 4 + np.pi * i/4)
93
- motion[frame, 1:4, 1] = wiggle # Side-to-side motion for upper body
94
 
95
  # Left leg (joints 4, 5, 6)
96
- leg_freq = speed * 2 # Frequency of leg movement
97
  swing_leg_l = np.sin(t * 2 * np.pi * leg_freq)
98
- motion[frame, 4, 1] = 0.2 # Left hip position
99
  motion[frame, 4, 2] = motion[frame, 0, 2] - 0.1
100
  motion[frame, 5, 1] = 0.2
101
  motion[frame, 5, 2] = motion[frame, 4, 2] - 0.5 + swing_leg_l * 0.3
@@ -103,8 +67,8 @@ def generate_motion(frames={}):
103
  motion[frame, 6, 2] = motion[frame, 5, 2] - 0.5 + swing_leg_l * 0.3
104
 
105
  # Right leg (joints 7, 8, 9)
106
- swing_leg_r = np.sin(t * 2 * np.pi * leg_freq + np.pi) # Opposite phase
107
- motion[frame, 7, 1] = -0.2 # Right hip position
108
  motion[frame, 7, 2] = motion[frame, 0, 2] - 0.1
109
  motion[frame, 8, 1] = -0.2
110
  motion[frame, 8, 2] = motion[frame, 7, 2] - 0.5 + swing_leg_r * 0.3
@@ -112,25 +76,15 @@ def generate_motion(frames={}):
112
  motion[frame, 9, 2] = motion[frame, 8, 2] - 0.5 + swing_leg_r * 0.3
113
 
114
  # Left arm (joints 10, 11, 12)
115
- if waving:
116
- # Waving motion for left arm
117
- if t > 0.3 and t < 0.7:
118
- wave = 0.5 * np.sin(t * 2 * np.pi * 8)
119
- motion[frame, 10, 1] = 0.3
120
- motion[frame, 10, 2] = motion[frame, 3, 2] - 0.2
121
- motion[frame, 11, 1] = 0.5
122
- motion[frame, 11, 2] = motion[frame, 10, 2]
123
- motion[frame, 12, 1] = 0.7
124
- motion[frame, 12, 2] = motion[frame, 11, 2] + wave
125
- else:
126
- # Normal arm swing during non-waving
127
- swing_arm_l = np.sin(t * 2 * np.pi * leg_freq + np.pi)
128
- motion[frame, 10, 1] = 0.3
129
- motion[frame, 10, 2] = motion[frame, 3, 2] - 0.2
130
- motion[frame, 11, 1] = 0.3 + swing_arm_l * 0.2
131
- motion[frame, 11, 2] = motion[frame, 10, 2] - 0.4
132
- motion[frame, 12, 1] = 0.3 + swing_arm_l * 0.4
133
- motion[frame, 12, 2] = motion[frame, 11, 2] - 0.4
134
  else:
135
  # Normal arm swing
136
  swing_arm_l = np.sin(t * 2 * np.pi * leg_freq + np.pi)
@@ -152,18 +106,13 @@ def generate_motion(frames={}):
152
 
153
  return motion
154
 
155
- def visualize_motion(output_path):
156
- # Generate motion data
157
- motion_data = generate_motion()
158
-
159
- # Get dimensions
160
- frames, joints, dims = motion_data.shape
161
-
162
  # Create figure
163
  fig = plt.figure(figsize=(10, 6))
164
  ax = fig.add_subplot(111, projection='3d')
165
 
166
- # Define connections between joints (simplified skeleton)
167
  connections = [
168
  (0, 1), (1, 2), (2, 3), # Spine and head
169
  (0, 4), (4, 5), (5, 6), # Left leg
@@ -172,6 +121,7 @@ def visualize_motion(output_path):
172
  (3, 13), (13, 14), (14, 15) # Right arm
173
  ]
174
 
 
175
  def update(frame):
176
  ax.clear()
177
 
@@ -196,65 +146,34 @@ def visualize_motion(output_path):
196
  ax.plot([motion_data[frame, start, 0], motion_data[frame, end, 0]],
197
  [motion_data[frame, start, 1], motion_data[frame, end, 1]],
198
  [motion_data[frame, start, 2], motion_data[frame, end, 2]], 'r-')
199
-
200
- # Set title
201
- action_type = ""
202
- if running:
203
- action_type = "Running"
204
- elif walking:
205
- action_type = "Walking"
206
- elif jumping:
207
- action_type = "Jumping"
208
- elif dancing:
209
- action_type = "Dancing"
210
- elif turning:
211
- action_type = "Turning"
212
- elif waving:
213
- action_type = "Waving"
214
- else:
215
- action_type = "Moving"
216
 
217
- ax.set_title(action_type + " Motion - Frame " + str(frame))
218
  return ax
219
 
220
  # Create animation
221
  anim = FuncAnimation(fig, update, frames=min(180, motion_data.shape[0]), interval=1000/30)
222
 
223
  # Save animation
224
- os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
225
  anim.save(output_path, writer='ffmpeg', fps=30)
226
  plt.close()
227
 
228
- print("Animation saved to " + output_path)
229
  return output_path
230
 
231
- # Create and visualize a motion
232
- visualize_motion("{}")
233
- """.format(
234
- int(seed), # First {}
235
- int(seed), # Second {}
236
- safe_prompt, # Third {}
237
- int(motion_length * 30), # Fourth {}
238
- output_filename # Fifth {}
239
- ))
240
-
241
- # Run the script
242
- subprocess.run(["python", "motion_animation.py"])
243
-
244
- if os.path.exists(output_filename):
245
- return output_filename
246
- else:
247
- return None
248
-
249
  def text_to_motion(text_prompt, motion_length=3.0, seed=0):
250
- """Generate motion from text prompt using MDM"""
251
  try:
252
- # Create a motion animation based on the input parameters
253
- return create_motion_animation(text_prompt, motion_length, seed)
 
 
 
 
 
 
254
 
255
  except Exception as e:
256
  print(f"Error generating motion: {str(e)}")
257
- print(traceback.format_exc()) # Print the full traceback
 
258
  return None
259
 
260
  # Create the Gradio interface
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  from matplotlib.animation import FuncAnimation
 
6
  from mpl_toolkits.mplot3d import Axes3D
7
  import random
8
 
9
+ def generate_motion(text_prompt, motion_length, seed):
10
+ """Generate a motion animation based on the text prompt"""
11
+ # Use the seed for reproducibility
12
+ np.random.seed(seed)
13
+ random.seed(seed)
14
+
15
+ # Parse the text prompt to detect actions
16
+ text_lower = text_prompt.lower()
17
+ walking = "walk" in text_lower
18
+ running = "run" in text_lower
19
+ jumping = "jump" in text_lower
20
+ dancing = "danc" in text_lower
21
+ turning = "turn" in text_lower or "spin" in text_lower
22
+ waving = "wave" in text_lower
23
+
24
+ # Set speed and other parameters based on the action
25
+ speed = 4.0 if running else 2.0 if walking else 1.0
26
+ frames = int(motion_length * 30) # 30 fps
27
+
28
+ # Create motion data - 16 joints with 3D coordinates
29
+ joints = 16
30
+ dims = 3
31
  motion = np.zeros((frames, joints, dims))
32
 
33
+ # Generate the motion
 
 
 
 
 
 
 
 
34
  for frame in range(frames):
35
  t = frame / frames
36
 
37
+ # Basic forward motion or turning
38
  if turning:
 
39
  angle = t * 2 * np.pi * 2
40
  motion[frame, :, 0] = np.cos(angle) * 2
41
  motion[frame, :, 1] = np.sin(angle) * 2
42
  else:
 
43
  motion[frame, :, 0] = t * speed * 4
44
 
45
+ # Root joint (pelvis) with jumping or bouncing
46
  if jumping:
47
+ motion[frame, 0, 2] = 0.5 + 0.5 * np.sin(t * 2 * np.pi * 3)
 
 
48
  else:
49
+ motion[frame, 0, 2] = 0.1 * np.sin(t * 2 * np.pi * speed * 2) + 1 if walking or running else 0.05 + 1
 
 
50
 
51
  # Spine and head (joints 1, 2, 3)
52
  for i in range(1, 4):
53
+ motion[frame, i, 2] = motion[frame, 0, 2] + i * 0.2
54
+
55
+ # Add dancing motion for upper body
56
+ if dancing:
57
+ motion[frame, i, 1] = 0.2 * np.sin(t * 2 * np.pi * 4 + np.pi * i/4)
 
58
 
59
  # Left leg (joints 4, 5, 6)
60
+ leg_freq = speed * 2
61
  swing_leg_l = np.sin(t * 2 * np.pi * leg_freq)
62
+ motion[frame, 4, 1] = 0.2
63
  motion[frame, 4, 2] = motion[frame, 0, 2] - 0.1
64
  motion[frame, 5, 1] = 0.2
65
  motion[frame, 5, 2] = motion[frame, 4, 2] - 0.5 + swing_leg_l * 0.3
 
67
  motion[frame, 6, 2] = motion[frame, 5, 2] - 0.5 + swing_leg_l * 0.3
68
 
69
  # Right leg (joints 7, 8, 9)
70
+ swing_leg_r = np.sin(t * 2 * np.pi * leg_freq + np.pi)
71
+ motion[frame, 7, 1] = -0.2
72
  motion[frame, 7, 2] = motion[frame, 0, 2] - 0.1
73
  motion[frame, 8, 1] = -0.2
74
  motion[frame, 8, 2] = motion[frame, 7, 2] - 0.5 + swing_leg_r * 0.3
 
76
  motion[frame, 9, 2] = motion[frame, 8, 2] - 0.5 + swing_leg_r * 0.3
77
 
78
  # Left arm (joints 10, 11, 12)
79
+ if waving and t > 0.3 and t < 0.7:
80
+ # Waving motion
81
+ wave = 0.5 * np.sin(t * 2 * np.pi * 8)
82
+ motion[frame, 10, 1] = 0.3
83
+ motion[frame, 10, 2] = motion[frame, 3, 2] - 0.2
84
+ motion[frame, 11, 1] = 0.5
85
+ motion[frame, 11, 2] = motion[frame, 10, 2]
86
+ motion[frame, 12, 1] = 0.7
87
+ motion[frame, 12, 2] = motion[frame, 11, 2] + wave
 
 
 
 
 
 
 
 
 
 
88
  else:
89
  # Normal arm swing
90
  swing_arm_l = np.sin(t * 2 * np.pi * leg_freq + np.pi)
 
106
 
107
  return motion
108
 
109
+ def visualize_motion(motion_data, output_path="output.mp4"):
110
+ """Visualize the motion data as a 3D animation"""
 
 
 
 
 
111
  # Create figure
112
  fig = plt.figure(figsize=(10, 6))
113
  ax = fig.add_subplot(111, projection='3d')
114
 
115
+ # Define connections between joints
116
  connections = [
117
  (0, 1), (1, 2), (2, 3), # Spine and head
118
  (0, 4), (4, 5), (5, 6), # Left leg
 
121
  (3, 13), (13, 14), (14, 15) # Right arm
122
  ]
123
 
124
+ # Animation update function
125
  def update(frame):
126
  ax.clear()
127
 
 
146
  ax.plot([motion_data[frame, start, 0], motion_data[frame, end, 0]],
147
  [motion_data[frame, start, 1], motion_data[frame, end, 1]],
148
  [motion_data[frame, start, 2], motion_data[frame, end, 2]], 'r-')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
 
150
  return ax
151
 
152
  # Create animation
153
  anim = FuncAnimation(fig, update, frames=min(180, motion_data.shape[0]), interval=1000/30)
154
 
155
  # Save animation
 
156
  anim.save(output_path, writer='ffmpeg', fps=30)
157
  plt.close()
158
 
 
159
  return output_path
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  def text_to_motion(text_prompt, motion_length=3.0, seed=0):
162
+ """Generate motion from text prompt"""
163
  try:
164
+ print(f"Generating motion for: '{text_prompt}', length: {motion_length}s, seed: {seed}")
165
+
166
+ # Create a unique filename
167
+ output_path = f"output_{abs(hash(text_prompt) % 10000)}_{int(motion_length)}_{seed}.mp4"
168
+
169
+ # Generate and visualize the motion
170
+ motion_data = generate_motion(text_prompt, motion_length, seed)
171
+ return visualize_motion(motion_data, output_path)
172
 
173
  except Exception as e:
174
  print(f"Error generating motion: {str(e)}")
175
+ import traceback
176
+ print(traceback.format_exc())
177
  return None
178
 
179
  # Create the Gradio interface