megalado commited on
Commit
5309894
·
1 Parent(s): 70c32f5

Avoid shell scripts and use direct Python commands

Browse files
Files changed (1) hide show
  1. app.py +141 -27
app.py CHANGED
@@ -8,6 +8,7 @@ import traceback
8
  import subprocess
9
  import glob
10
  import requests
 
11
 
12
  def download_checkpoint():
13
  """Download the recommended checkpoint if not present"""
@@ -64,44 +65,157 @@ def text_to_motion(text_prompt, motion_length=3.0, seed=0):
64
 
65
  # Download the recommended checkpoint
66
  checkpoint_path = download_checkpoint()
 
67
 
68
- # Write a simple run script that properly sets up the environment
69
- run_script = """
70
- #!/bin/bash
71
- cd motion-diffusion-model
72
- export PYTHONPATH=$PYTHONPATH:$(pwd)
73
- python -m sample.generate --model_path ../checkpoints/humanml_trans_enc_512.pt --text_prompt "$1" --motion_length $2 --seed $3 --num_samples 1
74
- """
75
 
76
- with open("run_mdm.sh", "w") as f:
77
- f.write(run_script)
78
 
79
- subprocess.run(["chmod", "+x", "run_mdm.sh"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- # Run the script
82
- cmd = ["./run_mdm.sh", text_prompt, str(motion_length), str(int(seed))]
83
- print(f"Running command: {' '.join(cmd)}")
84
- result = subprocess.run(cmd, capture_output=True, text=True)
 
85
 
86
- # Print the output for debugging
87
- print("Command output:", result.stdout)
88
- if result.stderr:
89
- print("Command error:", result.stderr)
90
 
91
- # Check for output files
92
- print("Checking for output files:")
93
- for root, dirs, files in os.walk("."):
94
- for file in files:
95
- if file.endswith(".mp4"):
96
- path = os.path.join(root, file)
97
- print(f"Found video file: {path}")
98
- return path
99
 
100
- print("No MP4 output files found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  return None
102
  except Exception as e:
103
  print(f"Error generating motion: {str(e)}")
104
  print(traceback.format_exc()) # Print the full traceback
 
 
 
 
 
105
  return None
106
 
107
  # Create the Gradio interface
 
8
  import subprocess
9
  import glob
10
  import requests
11
+ import time
12
 
13
  def download_checkpoint():
14
  """Download the recommended checkpoint if not present"""
 
65
 
66
  # Download the recommended checkpoint
67
  checkpoint_path = download_checkpoint()
68
+ absolute_checkpoint_path = os.path.abspath(checkpoint_path)
69
 
70
+ # Set up environment
71
+ original_dir = os.getcwd()
72
+ os.chdir("motion-diffusion-model")
 
 
 
 
73
 
74
+ # Add current directory to Python path
75
+ sys.path.insert(0, os.getcwd())
76
 
77
+ # Create a simple visualization script
78
+ with open("visualize_motion.py", "w") as f:
79
+ f.write("""
80
+ import numpy as np
81
+ import matplotlib.pyplot as plt
82
+ from matplotlib.animation import FuncAnimation
83
+ import sys
84
+ import os
85
+ from mpl_toolkits.mplot3d import Axes3D
86
+
87
+ def visualize_motion(motion_file, output_path):
88
+ # Load the motion data
89
+ motion_data = np.load(motion_file)
90
+
91
+ # Get dimensions
92
+ frames, joints, dims = motion_data.shape
93
+
94
+ # Create figure
95
+ fig = plt.figure(figsize=(10, 10))
96
+ ax = fig.add_subplot(111, projection='3d')
97
+
98
+ # Define connections between joints (simplified)
99
+ connections = [
100
+ (0, 1), (1, 2), (2, 3), # Spine
101
+ (0, 4), (4, 5), (5, 6), # Left leg
102
+ (0, 7), (7, 8), (8, 9), # Right leg
103
+ (3, 10), (10, 11), (11, 12), # Left arm
104
+ (3, 13), (13, 14), (14, 15) # Right arm
105
+ ]
106
+
107
+ def update(frame):
108
+ ax.clear()
109
 
110
+ # Set axis limits
111
+ max_range = np.max(np.abs(motion_data))
112
+ ax.set_xlim([-max_range, max_range])
113
+ ax.set_ylim([-max_range, max_range])
114
+ ax.set_zlim([-max_range, max_range])
115
 
116
+ # Plot joints
117
+ ax.scatter(motion_data[frame, :, 0],
118
+ motion_data[frame, :, 1],
119
+ motion_data[frame, :, 2], c='b', marker='o')
120
 
121
+ # Plot connections
122
+ for start, end in connections:
123
+ if start < joints and end < joints:
124
+ ax.plot([motion_data[frame, start, 0], motion_data[frame, end, 0]],
125
+ [motion_data[frame, start, 1], motion_data[frame, end, 1]],
126
+ [motion_data[frame, start, 2], motion_data[frame, end, 2]], 'r-')
 
 
127
 
128
+ ax.set_title(f"Frame {frame}")
129
+ return ax
130
+
131
+ # Create animation
132
+ anim = FuncAnimation(fig, update, frames=min(100, frames), interval=1000/30)
133
+
134
+ # Save animation
135
+ anim.save(output_path, writer='ffmpeg', fps=30)
136
+ plt.close()
137
+
138
+ return output_path
139
+
140
+ if __name__ == "__main__":
141
+ if len(sys.argv) < 3:
142
+ print("Usage: python visualize_motion.py <motion_file> <output_path>")
143
+ sys.exit(1)
144
+
145
+ motion_file = sys.argv[1]
146
+ output_path = sys.argv[2]
147
+
148
+ visualize_motion(motion_file, output_path)
149
+ """)
150
+
151
+ # Run the generation directly
152
+ print("Running MDM generation...")
153
+ generation_cmd = [
154
+ "python", "-c",
155
+ f"""
156
+ import sys
157
+ sys.path.insert(0, '.')
158
+ from sample.generate import generate
159
+ import os
160
+
161
+ # Run the generation
162
+ motion_data = generate(
163
+ model_path='{absolute_checkpoint_path}',
164
+ text_prompt='{text_prompt}',
165
+ motion_length={motion_length},
166
+ seeds=[{int(seed)}],
167
+ num_samples=1
168
+ )
169
+
170
+ # Save the motion data for visualization
171
+ import numpy as np
172
+ os.makedirs('output', exist_ok=True)
173
+ np.save('output/motion_data.npy', motion_data[0])
174
+ print("Motion data saved to output/motion_data.npy")
175
+ """
176
+ ]
177
+
178
+ print(f"Running generation command...")
179
+ gen_result = subprocess.run(generation_cmd, capture_output=True, text=True)
180
+
181
+ print("Generation output:", gen_result.stdout)
182
+ if gen_result.stderr:
183
+ print("Generation error:", gen_result.stderr)
184
+
185
+ # Check if motion data was generated
186
+ if Path("output/motion_data.npy").exists():
187
+ print("Motion data generated successfully!")
188
+
189
+ # Visualize the motion
190
+ print("Visualizing motion...")
191
+ viz_cmd = ["python", "visualize_motion.py", "output/motion_data.npy", "../output.mp4"]
192
+ viz_result = subprocess.run(viz_cmd, capture_output=True, text=True)
193
+
194
+ print("Visualization output:", viz_result.stdout)
195
+ if viz_result.stderr:
196
+ print("Visualization error:", viz_result.stderr)
197
+
198
+ # Return to original directory
199
+ os.chdir(original_dir)
200
+
201
+ # Check if the output file exists
202
+ if Path("output.mp4").exists():
203
+ print("Animation created successfully!")
204
+ return "output.mp4"
205
+
206
+ # Return to original directory
207
+ os.chdir(original_dir)
208
+
209
+ print("No output file generated.")
210
  return None
211
  except Exception as e:
212
  print(f"Error generating motion: {str(e)}")
213
  print(traceback.format_exc()) # Print the full traceback
214
+
215
+ # Return to original directory if an exception occurred
216
+ if 'original_dir' in locals():
217
+ os.chdir(original_dir)
218
+
219
  return None
220
 
221
  # Create the Gradio interface