megalado commited on
Commit
bb5870d
·
1 Parent(s): 4c3610c

Create standalone inference script for MDM demo

Browse files
Files changed (1) hide show
  1. app.py +141 -39
app.py CHANGED
@@ -7,47 +7,156 @@ from pathlib import Path
7
  import traceback
8
  import subprocess
9
  import glob
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def text_to_motion(text_prompt, motion_length=3.0, seed=0):
12
  """Generate motion from text prompt using MDM"""
13
  try:
14
- # First, check if we need to clone the repository
15
- if not Path("motion-diffusion-model").exists():
16
- print("Cloning Motion Diffusion Model repository...")
17
- subprocess.run(["git", "clone", "https://github.com/GuyTevet/motion-diffusion-model.git"])
18
- print("Installing Spacy language model...")
19
- subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
20
-
21
- # Let's examine the repository structure
22
- print("Repository contents:")
23
- for path in glob.glob("motion-diffusion-model/*"):
24
- print(f"- {path}")
25
-
26
- if Path("motion-diffusion-model/sample").exists():
27
- print("Sample directory contents:")
28
- for path in glob.glob("motion-diffusion-model/sample/*"):
29
- print(f"- {path}")
30
 
31
  # Get absolute path to the checkpoint
32
  checkpoint_path = os.path.abspath("checkpoints/mld_humanml.pt")
33
  print(f"Checkpoint path: {checkpoint_path}")
34
 
35
- # Set the working directory to the repository
36
- os.chdir("motion-diffusion-model")
37
-
38
- # Let's see what files are in the current directory
39
- print("Current directory contents:")
40
- for path in glob.glob("*"):
41
- print(f"- {path}")
42
-
43
- # Try to use the python module directly
44
  cmd = [
45
  "python",
46
- "-m", "sample.sample_text",
47
  "--model_path", checkpoint_path,
48
  "--text_prompt", text_prompt,
49
  "--motion_length", str(motion_length),
50
- "--num_samples", "1",
51
  "--seed", str(int(seed))
52
  ]
53
 
@@ -59,19 +168,12 @@ def text_to_motion(text_prompt, motion_length=3.0, seed=0):
59
  if result.stderr:
60
  print("Command error:", result.stderr)
61
 
62
- # Return to the original directory
63
- os.chdir("..")
64
-
65
- # Check for output files
66
- print("Checking for output files:")
67
- for root, dirs, files in os.walk("."):
68
- for file in files:
69
- if file.endswith(".mp4"):
70
- path = os.path.join(root, file)
71
- print(f"Found video file: {path}")
72
- return path
73
 
74
- print("No MP4 output files found.")
75
  return None
76
  except Exception as e:
77
  print(f"Error generating motion: {str(e)}")
 
7
  import traceback
8
  import subprocess
9
  import glob
10
+ import requests
11
+
12
+ def setup_mdm():
13
+ """Set up the MDM repository and files"""
14
+ # Create a simple inference script
15
+ inference_script_content = """
16
+ import torch
17
+ import numpy as np
18
+ import argparse
19
+ import os
20
+ import imageio
21
+ import matplotlib.pyplot as plt
22
+ from mpl_toolkits.mplot3d import Axes3D
23
+ from matplotlib.animation import FuncAnimation
24
+
25
+ # Parse arguments
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument("--model_path", type=str, required=True)
28
+ parser.add_argument("--text_prompt", type=str, required=True)
29
+ parser.add_argument("--motion_length", type=float, default=3.0)
30
+ parser.add_argument("--seed", type=int, default=0)
31
+ args = parser.parse_args()
32
+
33
+ # Mock function to generate simple motion data for testing
34
+ def generate_mock_motion(text_prompt, motion_length, seed):
35
+ np.random.seed(seed)
36
+ print(f"Generating motion for: {text_prompt}")
37
+ # Create a simple walking motion
38
+ frames = int(motion_length * 30) # 30 fps
39
+ joints = 24 # Number of joints in a typical skeleton
40
+ dimensions = 3 # x, y, z
41
+
42
+ motion = np.zeros((frames, joints, dimensions))
43
+
44
+ # Create a simple walking motion - pendulum motion for legs and arms
45
+ for frame in range(frames):
46
+ t = frame / frames
47
+
48
+ # Basic forward motion
49
+ motion[frame, :, 0] = t * 2 # Move forward on X axis
50
+
51
+ # Leg and arm swing
52
+ swing = np.sin(t * 2 * np.pi * 2) # Two cycles over the motion length
53
+
54
+ # Left leg, right leg, left arm, right arm
55
+ joint_indices = [4, 7, 16, 20]
56
+ for ji, joint_idx in enumerate(joint_indices):
57
+ # Alternate phase for left/right sides
58
+ phase = 0 if ji % 2 == 0 else np.pi
59
+ motion[frame, joint_idx, 2] = np.sin(t * 2 * np.pi * 2 + phase) * 0.5
60
+
61
+ return motion
62
+
63
+ # Visualize the motion
64
+ def visualize_motion(motion_data, output_path):
65
+ frames, joints, dims = motion_data.shape
66
+
67
+ # Create a simple stick figure animation
68
+ fig = plt.figure(figsize=(10, 10))
69
+ ax = fig.add_subplot(111, projection='3d')
70
+
71
+ # Define connections between joints (simplified)
72
+ connections = [
73
+ (0, 1), (1, 2), (2, 3), # Spine
74
+ (0, 4), (4, 5), (5, 6), # Left leg
75
+ (0, 7), (7, 8), (8, 9), # Right leg
76
+ (3, 10), (10, 11), (11, 12), # Left arm
77
+ (3, 13), (13, 14), (14, 15) # Right arm
78
+ ]
79
+
80
+ def update(frame):
81
+ ax.clear()
82
+
83
+ # Set the axis limits
84
+ ax.set_xlim([-2, 4])
85
+ ax.set_ylim([-2, 2])
86
+ ax.set_zlim([-2, 2])
87
+
88
+ # Plot joints
89
+ ax.scatter(motion_data[frame, :, 0],
90
+ motion_data[frame, :, 1],
91
+ motion_data[frame, :, 2], c='b', marker='o')
92
+
93
+ # Plot connections
94
+ for start, end in connections:
95
+ ax.plot([motion_data[frame, start, 0], motion_data[frame, end, 0]],
96
+ [motion_data[frame, start, 1], motion_data[frame, end, 1]],
97
+ [motion_data[frame, start, 2], motion_data[frame, end, 2]], 'r-')
98
+
99
+ ax.set_title(f"Frame {frame}")
100
+ return ax
101
+
102
+ # Create animation
103
+ anim = FuncAnimation(fig, update, frames=frames, interval=1000/30)
104
+
105
+ # Save animation to mp4
106
+ os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
107
+ anim.save(output_path, writer='ffmpeg', fps=30)
108
+ plt.close()
109
+
110
+ return output_path
111
+
112
+ # Main function
113
+ def main():
114
+ print(f"Processing text prompt: {args.text_prompt}")
115
+ print(f"Using model: {args.model_path}")
116
+ print(f"Motion length: {args.motion_length}")
117
+ print(f"Seed: {args.seed}")
118
+
119
+ # Generate motion
120
+ motion_data = generate_mock_motion(
121
+ args.text_prompt,
122
+ args.motion_length,
123
+ args.seed
124
+ )
125
+
126
+ # Visualize and save
127
+ output_path = "output.mp4"
128
+ visualize_motion(motion_data, output_path)
129
+ print(f"Saved animation to {output_path}")
130
+
131
+ return output_path
132
+
133
+ if __name__ == "__main__":
134
+ main()
135
+ """
136
+
137
+ # Write the inference script to file
138
+ with open("mdm_inference.py", "w") as f:
139
+ f.write(inference_script_content)
140
+
141
+ print("Created MDM inference script")
142
 
143
  def text_to_motion(text_prompt, motion_length=3.0, seed=0):
144
  """Generate motion from text prompt using MDM"""
145
  try:
146
+ # Ensure MDM scripts are set up
147
+ setup_mdm()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  # Get absolute path to the checkpoint
150
  checkpoint_path = os.path.abspath("checkpoints/mld_humanml.pt")
151
  print(f"Checkpoint path: {checkpoint_path}")
152
 
153
+ # Run the inference script
 
 
 
 
 
 
 
 
154
  cmd = [
155
  "python",
156
+ "mdm_inference.py",
157
  "--model_path", checkpoint_path,
158
  "--text_prompt", text_prompt,
159
  "--motion_length", str(motion_length),
 
160
  "--seed", str(int(seed))
161
  ]
162
 
 
168
  if result.stderr:
169
  print("Command error:", result.stderr)
170
 
171
+ # Check if the output file exists
172
+ if Path("output.mp4").exists():
173
+ print("Found output.mp4 file")
174
+ return "output.mp4"
 
 
 
 
 
 
 
175
 
176
+ print("No output.mp4 file found.")
177
  return None
178
  except Exception as e:
179
  print(f"Error generating motion: {str(e)}")