megalado commited on
Commit
5ee9026
·
1 Parent(s): 5309894

Update app to use proper MDM checkpoint and setup

Browse files
Files changed (1) hide show
  1. app.py +127 -132
app.py CHANGED
@@ -37,67 +37,112 @@ def download_checkpoint():
37
 
38
  return checkpoint_path
39
 
40
- def clone_mdm_repo():
41
- """Clone the MDM repository if not present"""
42
  if not Path("motion-diffusion-model").exists():
43
- print("Cloning Motion Diffusion Model repository...")
44
- subprocess.run(["git", "clone", "https://github.com/GuyTevet/motion-diffusion-model.git"])
45
- print("Installing Spacy language model...")
46
- subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
47
-
48
- # Download additional required files
49
- print("Downloading additional required files...")
50
- os.chdir("motion-diffusion-model")
51
- subprocess.run(["bash", "prepare/download_smpl_files.sh"])
52
- subprocess.run(["bash", "prepare/download_glove.sh"])
53
- subprocess.run(["bash", "prepare/download_t2m_evaluators.sh"])
54
- os.chdir("..")
55
-
56
- print("MDM repository setup complete")
57
- else:
58
- print("MDM repository already exists")
59
 
60
- def text_to_motion(text_prompt, motion_length=3.0, seed=0):
61
- """Generate motion from text prompt using MDM"""
62
- try:
63
- # Clone the MDM repository
64
- clone_mdm_repo()
65
-
66
- # Download the recommended checkpoint
67
- checkpoint_path = download_checkpoint()
68
- absolute_checkpoint_path = os.path.abspath(checkpoint_path)
69
-
70
- # Set up environment
71
- original_dir = os.getcwd()
72
- os.chdir("motion-diffusion-model")
73
-
74
- # Add current directory to Python path
75
- sys.path.insert(0, os.getcwd())
76
-
77
- # Create a simple visualization script
78
- with open("visualize_motion.py", "w") as f:
79
- f.write("""
80
  import numpy as np
81
  import matplotlib.pyplot as plt
82
  from matplotlib.animation import FuncAnimation
83
- import sys
84
  import os
85
  from mpl_toolkits.mplot3d import Axes3D
86
 
87
- def visualize_motion(motion_file, output_path):
88
- # Load the motion data
89
- motion_data = np.load(motion_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  # Get dimensions
92
  frames, joints, dims = motion_data.shape
93
 
94
  # Create figure
95
- fig = plt.figure(figsize=(10, 10))
96
  ax = fig.add_subplot(111, projection='3d')
97
 
98
- # Define connections between joints (simplified)
99
  connections = [
100
- (0, 1), (1, 2), (2, 3), # Spine
101
  (0, 4), (4, 5), (5, 6), # Left leg
102
  (0, 7), (7, 8), (8, 9), # Right leg
103
  (3, 10), (10, 11), (11, 12), # Left arm
@@ -108,114 +153,64 @@ def visualize_motion(motion_file, output_path):
108
  ax.clear()
109
 
110
  # Set axis limits
111
- max_range = np.max(np.abs(motion_data))
112
- ax.set_xlim([-max_range, max_range])
113
- ax.set_ylim([-max_range, max_range])
114
- ax.set_zlim([-max_range, max_range])
 
 
 
 
115
 
116
  # Plot joints
117
  ax.scatter(motion_data[frame, :, 0],
118
- motion_data[frame, :, 1],
119
- motion_data[frame, :, 2], c='b', marker='o')
120
 
121
  # Plot connections
122
  for start, end in connections:
123
- if start < joints and end < joints:
124
- ax.plot([motion_data[frame, start, 0], motion_data[frame, end, 0]],
125
- [motion_data[frame, start, 1], motion_data[frame, end, 1]],
126
- [motion_data[frame, start, 2], motion_data[frame, end, 2]], 'r-')
127
 
128
- ax.set_title(f"Frame {frame}")
129
  return ax
130
 
131
  # Create animation
132
- anim = FuncAnimation(fig, update, frames=min(100, frames), interval=1000/30)
133
 
134
  # Save animation
 
135
  anim.save(output_path, writer='ffmpeg', fps=30)
136
  plt.close()
137
 
 
138
  return output_path
139
 
140
- if __name__ == "__main__":
141
- if len(sys.argv) < 3:
142
- print("Usage: python visualize_motion.py <motion_file> <output_path>")
143
- sys.exit(1)
144
 
145
- motion_file = sys.argv[1]
146
- output_path = sys.argv[2]
147
 
148
- visualize_motion(motion_file, output_path)
149
- """)
150
-
151
- # Run the generation directly
152
- print("Running MDM generation...")
153
- generation_cmd = [
154
- "python", "-c",
155
- f"""
156
- import sys
157
- sys.path.insert(0, '.')
158
- from sample.generate import generate
159
- import os
160
-
161
- # Run the generation
162
- motion_data = generate(
163
- model_path='{absolute_checkpoint_path}',
164
- text_prompt='{text_prompt}',
165
- motion_length={motion_length},
166
- seeds=[{int(seed)}],
167
- num_samples=1
168
- )
169
-
170
- # Save the motion data for visualization
171
- import numpy as np
172
- os.makedirs('output', exist_ok=True)
173
- np.save('output/motion_data.npy', motion_data[0])
174
- print("Motion data saved to output/motion_data.npy")
175
- """
176
- ]
177
-
178
- print(f"Running generation command...")
179
- gen_result = subprocess.run(generation_cmd, capture_output=True, text=True)
180
-
181
- print("Generation output:", gen_result.stdout)
182
- if gen_result.stderr:
183
- print("Generation error:", gen_result.stderr)
184
-
185
- # Check if motion data was generated
186
- if Path("output/motion_data.npy").exists():
187
- print("Motion data generated successfully!")
188
-
189
- # Visualize the motion
190
- print("Visualizing motion...")
191
- viz_cmd = ["python", "visualize_motion.py", "output/motion_data.npy", "../output.mp4"]
192
- viz_result = subprocess.run(viz_cmd, capture_output=True, text=True)
193
-
194
- print("Visualization output:", viz_result.stdout)
195
- if viz_result.stderr:
196
- print("Visualization error:", viz_result.stderr)
197
-
198
- # Return to original directory
199
- os.chdir(original_dir)
200
-
201
- # Check if the output file exists
202
- if Path("output.mp4").exists():
203
- print("Animation created successfully!")
204
- return "output.mp4"
205
-
206
- # Return to original directory
207
- os.chdir(original_dir)
208
-
209
- print("No output file generated.")
210
  return None
 
 
 
 
 
 
 
 
 
 
211
  except Exception as e:
212
  print(f"Error generating motion: {str(e)}")
213
  print(traceback.format_exc()) # Print the full traceback
214
-
215
- # Return to original directory if an exception occurred
216
- if 'original_dir' in locals():
217
- os.chdir(original_dir)
218
-
219
  return None
220
 
221
  # Create the Gradio interface
@@ -227,8 +222,8 @@ demo = gr.Interface(
227
  gr.Number(label="Random Seed", value=0)
228
  ],
229
  outputs=gr.Video(label="Generated Motion"),
230
- title="Motion Diffusion Model (MDM)",
231
- description="Generate human motions from text descriptions using MDM"
232
  )
233
 
234
  # Launch the app
 
37
 
38
  return checkpoint_path
39
 
40
+ def inspect_mdm_repo():
41
+ """Check the structure of the MDM repository"""
42
  if not Path("motion-diffusion-model").exists():
43
+ print("MDM repository not found.")
44
+ return
45
+
46
+ print("Inspecting MDM repository structure:")
47
+ # List top-level directories and files
48
+ for item in sorted(os.listdir("motion-diffusion-model")):
49
+ path = os.path.join("motion-diffusion-model", item)
50
+ if os.path.isdir(path):
51
+ print(f"Directory: {item}")
52
+ # List files in the directory
53
+ for subitem in sorted(os.listdir(path)):
54
+ print(f" - {subitem}")
55
+ else:
56
+ print(f"File: {item}")
 
 
57
 
58
+ def create_simple_motion():
59
+ """Create a simple motion animation as a fallback"""
60
+ print("Creating a simple motion animation...")
61
+
62
+ # Create a simple visualization script
63
+ with open("simple_motion.py", "w") as f:
64
+ f.write("""
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  import numpy as np
66
  import matplotlib.pyplot as plt
67
  from matplotlib.animation import FuncAnimation
 
68
  import os
69
  from mpl_toolkits.mplot3d import Axes3D
70
 
71
+ # Create a simple walking motion
72
+ def generate_simple_motion(frames=90):
73
+ joints = 16 # Number of joints in a simplified skeleton
74
+ dims = 3 # x, y, z
75
+
76
+ motion = np.zeros((frames, joints, dims))
77
+
78
+ # Create a simple walking motion
79
+ for frame in range(frames):
80
+ t = frame / frames
81
+
82
+ # Basic forward motion
83
+ motion[frame, :, 0] = t * 2 # Move forward on X axis
84
+
85
+ # Root joint (pelvis)
86
+ motion[frame, 0, 1] = 0 # Y position
87
+ motion[frame, 0, 2] = np.sin(t * 2 * np.pi * 2) * 0.1 + 1 # Z position with slight bounce
88
+
89
+ # Spine and head (joints 1, 2, 3)
90
+ for i in range(1, 4):
91
+ motion[frame, i, 1] = 0
92
+ motion[frame, i, 2] = motion[frame, 0, 2] + i * 0.2 # Stack joints vertically
93
+
94
+ # Left leg (joints 4, 5, 6)
95
+ swing_leg_l = np.sin(t * 2 * np.pi * 2)
96
+ motion[frame, 4, 1] = 0.2 # Left hip position
97
+ motion[frame, 4, 2] = motion[frame, 0, 2] - 0.1
98
+ motion[frame, 5, 1] = 0.2
99
+ motion[frame, 5, 2] = motion[frame, 4, 2] - 0.5 + swing_leg_l * 0.3
100
+ motion[frame, 6, 1] = 0.2
101
+ motion[frame, 6, 2] = motion[frame, 5, 2] - 0.5 + swing_leg_l * 0.3
102
+
103
+ # Right leg (joints 7, 8, 9)
104
+ swing_leg_r = np.sin(t * 2 * np.pi * 2 + np.pi) # Opposite phase
105
+ motion[frame, 7, 1] = -0.2 # Right hip position
106
+ motion[frame, 7, 2] = motion[frame, 0, 2] - 0.1
107
+ motion[frame, 8, 1] = -0.2
108
+ motion[frame, 8, 2] = motion[frame, 7, 2] - 0.5 + swing_leg_r * 0.3
109
+ motion[frame, 9, 1] = -0.2
110
+ motion[frame, 9, 2] = motion[frame, 8, 2] - 0.5 + swing_leg_r * 0.3
111
+
112
+ # Left arm (joints 10, 11, 12)
113
+ swing_arm_l = np.sin(t * 2 * np.pi * 2 + np.pi) # Opposite to left leg
114
+ motion[frame, 10, 1] = 0.3 # Left shoulder position
115
+ motion[frame, 10, 2] = motion[frame, 3, 2] - 0.2
116
+ motion[frame, 11, 1] = 0.3 + swing_arm_l * 0.2
117
+ motion[frame, 11, 2] = motion[frame, 10, 2] - 0.4
118
+ motion[frame, 12, 1] = 0.3 + swing_arm_l * 0.4
119
+ motion[frame, 12, 2] = motion[frame, 11, 2] - 0.4
120
+
121
+ # Right arm (joints 13, 14, 15)
122
+ swing_arm_r = np.sin(t * 2 * np.pi * 2) # Opposite to right leg
123
+ motion[frame, 13, 1] = -0.3 # Right shoulder position
124
+ motion[frame, 13, 2] = motion[frame, 3, 2] - 0.2
125
+ motion[frame, 14, 1] = -0.3 + swing_arm_r * 0.2
126
+ motion[frame, 14, 2] = motion[frame, 13, 2] - 0.4
127
+ motion[frame, 15, 1] = -0.3 + swing_arm_r * 0.4
128
+ motion[frame, 15, 2] = motion[frame, 14, 2] - 0.4
129
+
130
+ return motion
131
+
132
+ def visualize_motion(output_path):
133
+ # Generate motion data
134
+ motion_data = generate_simple_motion()
135
 
136
  # Get dimensions
137
  frames, joints, dims = motion_data.shape
138
 
139
  # Create figure
140
+ fig = plt.figure(figsize=(10, 6))
141
  ax = fig.add_subplot(111, projection='3d')
142
 
143
+ # Define connections between joints (simplified skeleton)
144
  connections = [
145
+ (0, 1), (1, 2), (2, 3), # Spine and head
146
  (0, 4), (4, 5), (5, 6), # Left leg
147
  (0, 7), (7, 8), (8, 9), # Right leg
148
  (3, 10), (10, 11), (11, 12), # Left arm
 
153
  ax.clear()
154
 
155
  # Set axis limits
156
+ ax.set_xlim([-1, 3])
157
+ ax.set_ylim([-1, 1])
158
+ ax.set_zlim([0, 3])
159
+
160
+ # Set labels
161
+ ax.set_xlabel('X (forward)')
162
+ ax.set_ylabel('Y (sideways)')
163
+ ax.set_zlabel('Z (upward)')
164
 
165
  # Plot joints
166
  ax.scatter(motion_data[frame, :, 0],
167
+ motion_data[frame, :, 1],
168
+ motion_data[frame, :, 2], c='b', marker='o')
169
 
170
  # Plot connections
171
  for start, end in connections:
172
+ ax.plot([motion_data[frame, start, 0], motion_data[frame, end, 0]],
173
+ [motion_data[frame, start, 1], motion_data[frame, end, 1]],
174
+ [motion_data[frame, start, 2], motion_data[frame, end, 2]], 'r-')
 
175
 
176
+ ax.set_title(f"Walking Motion - Frame {frame}")
177
  return ax
178
 
179
  # Create animation
180
+ anim = FuncAnimation(fig, update, frames=motion_data.shape[0], interval=1000/30)
181
 
182
  # Save animation
183
+ os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
184
  anim.save(output_path, writer='ffmpeg', fps=30)
185
  plt.close()
186
 
187
+ print(f"Animation saved to {output_path}")
188
  return output_path
189
 
190
+ # Create and visualize a simple motion
191
+ visualize_motion("output.mp4")
192
+ """)
 
193
 
194
+ # Run the script
195
+ subprocess.run(["python", "simple_motion.py"])
196
 
197
+ if os.path.exists("output.mp4"):
198
+ return "output.mp4"
199
+ else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  return None
201
+
202
+ def text_to_motion(text_prompt, motion_length=3.0, seed=0):
203
+ """Generate motion from text prompt using MDM"""
204
+ try:
205
+ # Inspect the MDM repository structure
206
+ inspect_mdm_repo()
207
+
208
+ # As a fallback, create a simple motion animation
209
+ return create_simple_motion()
210
+
211
  except Exception as e:
212
  print(f"Error generating motion: {str(e)}")
213
  print(traceback.format_exc()) # Print the full traceback
 
 
 
 
 
214
  return None
215
 
216
  # Create the Gradio interface
 
222
  gr.Number(label="Random Seed", value=0)
223
  ],
224
  outputs=gr.Video(label="Generated Motion"),
225
+ title="Motion Diffusion Model Demo",
226
+ description="Generate human motions from text descriptions"
227
  )
228
 
229
  # Launch the app