megalado commited on
Commit
70c32f5
·
1 Parent(s): bb5870d

Update app to use proper MDM checkpoint and setup

Browse files
Files changed (1) hide show
  1. app.py +65 -141
app.py CHANGED
@@ -9,157 +9,77 @@ import subprocess
9
  import glob
10
  import requests
11
 
12
- def setup_mdm():
13
- """Set up the MDM repository and files"""
14
- # Create a simple inference script
15
- inference_script_content = """
16
- import torch
17
- import numpy as np
18
- import argparse
19
- import os
20
- import imageio
21
- import matplotlib.pyplot as plt
22
- from mpl_toolkits.mplot3d import Axes3D
23
- from matplotlib.animation import FuncAnimation
24
-
25
- # Parse arguments
26
- parser = argparse.ArgumentParser()
27
- parser.add_argument("--model_path", type=str, required=True)
28
- parser.add_argument("--text_prompt", type=str, required=True)
29
- parser.add_argument("--motion_length", type=float, default=3.0)
30
- parser.add_argument("--seed", type=int, default=0)
31
- args = parser.parse_args()
32
-
33
- # Mock function to generate simple motion data for testing
34
- def generate_mock_motion(text_prompt, motion_length, seed):
35
- np.random.seed(seed)
36
- print(f"Generating motion for: {text_prompt}")
37
- # Create a simple walking motion
38
- frames = int(motion_length * 30) # 30 fps
39
- joints = 24 # Number of joints in a typical skeleton
40
- dimensions = 3 # x, y, z
41
 
42
- motion = np.zeros((frames, joints, dimensions))
 
43
 
44
- # Create a simple walking motion - pendulum motion for legs and arms
45
- for frame in range(frames):
46
- t = frame / frames
47
-
48
- # Basic forward motion
49
- motion[frame, :, 0] = t * 2 # Move forward on X axis
50
 
51
- # Leg and arm swing
52
- swing = np.sin(t * 2 * np.pi * 2) # Two cycles over the motion length
 
53
 
54
- # Left leg, right leg, left arm, right arm
55
- joint_indices = [4, 7, 16, 20]
56
- for ji, joint_idx in enumerate(joint_indices):
57
- # Alternate phase for left/right sides
58
- phase = 0 if ji % 2 == 0 else np.pi
59
- motion[frame, joint_idx, 2] = np.sin(t * 2 * np.pi * 2 + phase) * 0.5
60
 
61
- return motion
62
-
63
- # Visualize the motion
64
- def visualize_motion(motion_data, output_path):
65
- frames, joints, dims = motion_data.shape
66
 
67
- # Create a simple stick figure animation
68
- fig = plt.figure(figsize=(10, 10))
69
- ax = fig.add_subplot(111, projection='3d')
70
-
71
- # Define connections between joints (simplified)
72
- connections = [
73
- (0, 1), (1, 2), (2, 3), # Spine
74
- (0, 4), (4, 5), (5, 6), # Left leg
75
- (0, 7), (7, 8), (8, 9), # Right leg
76
- (3, 10), (10, 11), (11, 12), # Left arm
77
- (3, 13), (13, 14), (14, 15) # Right arm
78
- ]
79
-
80
- def update(frame):
81
- ax.clear()
82
-
83
- # Set the axis limits
84
- ax.set_xlim([-2, 4])
85
- ax.set_ylim([-2, 2])
86
- ax.set_zlim([-2, 2])
87
-
88
- # Plot joints
89
- ax.scatter(motion_data[frame, :, 0],
90
- motion_data[frame, :, 1],
91
- motion_data[frame, :, 2], c='b', marker='o')
92
 
93
- # Plot connections
94
- for start, end in connections:
95
- ax.plot([motion_data[frame, start, 0], motion_data[frame, end, 0]],
96
- [motion_data[frame, start, 1], motion_data[frame, end, 1]],
97
- [motion_data[frame, start, 2], motion_data[frame, end, 2]], 'r-')
 
 
98
 
99
- ax.set_title(f"Frame {frame}")
100
- return ax
101
-
102
- # Create animation
103
- anim = FuncAnimation(fig, update, frames=frames, interval=1000/30)
104
-
105
- # Save animation to mp4
106
- os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
107
- anim.save(output_path, writer='ffmpeg', fps=30)
108
- plt.close()
109
-
110
- return output_path
111
-
112
- # Main function
113
- def main():
114
- print(f"Processing text prompt: {args.text_prompt}")
115
- print(f"Using model: {args.model_path}")
116
- print(f"Motion length: {args.motion_length}")
117
- print(f"Seed: {args.seed}")
118
-
119
- # Generate motion
120
- motion_data = generate_mock_motion(
121
- args.text_prompt,
122
- args.motion_length,
123
- args.seed
124
- )
125
-
126
- # Visualize and save
127
- output_path = "output.mp4"
128
- visualize_motion(motion_data, output_path)
129
- print(f"Saved animation to {output_path}")
130
-
131
- return output_path
132
-
133
- if __name__ == "__main__":
134
- main()
135
- """
136
-
137
- # Write the inference script to file
138
- with open("mdm_inference.py", "w") as f:
139
- f.write(inference_script_content)
140
-
141
- print("Created MDM inference script")
142
 
143
  def text_to_motion(text_prompt, motion_length=3.0, seed=0):
144
  """Generate motion from text prompt using MDM"""
145
  try:
146
- # Ensure MDM scripts are set up
147
- setup_mdm()
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- # Get absolute path to the checkpoint
150
- checkpoint_path = os.path.abspath("checkpoints/mld_humanml.pt")
151
- print(f"Checkpoint path: {checkpoint_path}")
152
 
153
- # Run the inference script
154
- cmd = [
155
- "python",
156
- "mdm_inference.py",
157
- "--model_path", checkpoint_path,
158
- "--text_prompt", text_prompt,
159
- "--motion_length", str(motion_length),
160
- "--seed", str(int(seed))
161
- ]
162
 
 
 
163
  print(f"Running command: {' '.join(cmd)}")
164
  result = subprocess.run(cmd, capture_output=True, text=True)
165
 
@@ -168,12 +88,16 @@ def text_to_motion(text_prompt, motion_length=3.0, seed=0):
168
  if result.stderr:
169
  print("Command error:", result.stderr)
170
 
171
- # Check if the output file exists
172
- if Path("output.mp4").exists():
173
- print("Found output.mp4 file")
174
- return "output.mp4"
 
 
 
 
175
 
176
- print("No output.mp4 file found.")
177
  return None
178
  except Exception as e:
179
  print(f"Error generating motion: {str(e)}")
 
9
  import glob
10
  import requests
11
 
12
+ def download_checkpoint():
13
+ """Download the recommended checkpoint if not present"""
14
+ # Create checkpoints directory
15
+ os.makedirs("checkpoints", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Define the target checkpoint path
18
+ checkpoint_path = "checkpoints/humanml_trans_enc_512.pt"
19
 
20
+ if not Path(checkpoint_path).exists():
21
+ print(f"Downloading checkpoint to {checkpoint_path}...")
22
+ # URL for the checkpoint from HuggingFace or direct link
23
+ url = "https://huggingface.co/spaces/mohaed/testMDM/resolve/main/checkpoints/mld_humanml.pt"
 
 
24
 
25
+ # Download the file
26
+ response = requests.get(url, stream=True)
27
+ response.raise_for_status()
28
 
29
+ with open(checkpoint_path, 'wb') as f:
30
+ for chunk in response.iter_content(chunk_size=8192):
31
+ f.write(chunk)
 
 
 
32
 
33
+ print(f"Checkpoint downloaded to {checkpoint_path}")
34
+ else:
35
+ print(f"Checkpoint already exists at {checkpoint_path}")
 
 
36
 
37
+ return checkpoint_path
38
+
39
+ def clone_mdm_repo():
40
+ """Clone the MDM repository if not present"""
41
+ if not Path("motion-diffusion-model").exists():
42
+ print("Cloning Motion Diffusion Model repository...")
43
+ subprocess.run(["git", "clone", "https://github.com/GuyTevet/motion-diffusion-model.git"])
44
+ print("Installing Spacy language model...")
45
+ subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # Download additional required files
48
+ print("Downloading additional required files...")
49
+ os.chdir("motion-diffusion-model")
50
+ subprocess.run(["bash", "prepare/download_smpl_files.sh"])
51
+ subprocess.run(["bash", "prepare/download_glove.sh"])
52
+ subprocess.run(["bash", "prepare/download_t2m_evaluators.sh"])
53
+ os.chdir("..")
54
 
55
+ print("MDM repository setup complete")
56
+ else:
57
+ print("MDM repository already exists")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def text_to_motion(text_prompt, motion_length=3.0, seed=0):
60
  """Generate motion from text prompt using MDM"""
61
  try:
62
+ # Clone the MDM repository
63
+ clone_mdm_repo()
64
+
65
+ # Download the recommended checkpoint
66
+ checkpoint_path = download_checkpoint()
67
+
68
+ # Write a simple run script that properly sets up the environment
69
+ run_script = """
70
+ #!/bin/bash
71
+ cd motion-diffusion-model
72
+ export PYTHONPATH=$PYTHONPATH:$(pwd)
73
+ python -m sample.generate --model_path ../checkpoints/humanml_trans_enc_512.pt --text_prompt "$1" --motion_length $2 --seed $3 --num_samples 1
74
+ """
75
 
76
+ with open("run_mdm.sh", "w") as f:
77
+ f.write(run_script)
 
78
 
79
+ subprocess.run(["chmod", "+x", "run_mdm.sh"])
 
 
 
 
 
 
 
 
80
 
81
+ # Run the script
82
+ cmd = ["./run_mdm.sh", text_prompt, str(motion_length), str(int(seed))]
83
  print(f"Running command: {' '.join(cmd)}")
84
  result = subprocess.run(cmd, capture_output=True, text=True)
85
 
 
88
  if result.stderr:
89
  print("Command error:", result.stderr)
90
 
91
+ # Check for output files
92
+ print("Checking for output files:")
93
+ for root, dirs, files in os.walk("."):
94
+ for file in files:
95
+ if file.endswith(".mp4"):
96
+ path = os.path.join(root, file)
97
+ print(f"Found video file: {path}")
98
+ return path
99
 
100
+ print("No MP4 output files found.")
101
  return None
102
  except Exception as e:
103
  print(f"Error generating motion: {str(e)}")