Update app.py
Browse files
app.py
CHANGED
|
@@ -80,13 +80,15 @@ def generate_motion(text, vqvae_model, transformer_model):
|
|
| 80 |
with torch.no_grad():
|
| 81 |
clip_features = clip_model.encode_text(text_encoded).float()
|
| 82 |
|
| 83 |
-
motion_indices = transformer_model.sample(clip_features, False)
|
| 84 |
pred_pose = vqvae_model.forward_decoder(motion_indices)
|
| 85 |
pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
|
| 86 |
|
| 87 |
return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
|
| 88 |
|
| 89 |
def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
|
|
|
|
|
|
|
| 90 |
fig = plt.figure(figsize=(10, 10))
|
| 91 |
ax = fig.add_subplot(111, projection='3d')
|
| 92 |
data = np.array(joints).T
|
|
@@ -101,7 +103,7 @@ def create_animation(joints, title="3D Motion", save_path="static/animation.gif"
|
|
| 101 |
ani.save(save_path, writer=PillowWriter(fps=20))
|
| 102 |
plt.close(fig)
|
| 103 |
return save_path
|
| 104 |
-
|
| 105 |
def infer(text):
|
| 106 |
motion_data = generate_motion(text, vqvae_model, transformer_model)
|
| 107 |
gif_path = create_animation(motion_data)
|
|
@@ -109,9 +111,9 @@ def infer(text):
|
|
| 109 |
|
| 110 |
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
|
| 111 |
with gr.Column():
|
| 112 |
-
gr.Markdown("## 3D Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
|
| 113 |
text_input = gr.Textbox(label="Describe the action", placeholder="Enter text description for the action here...")
|
| 114 |
-
output_image = gr.Image(label="Generated Motion
|
| 115 |
submit_button = gr.Button("Generate Motion")
|
| 116 |
|
| 117 |
submit_button.click(
|
|
|
|
| 80 |
with torch.no_grad():
|
| 81 |
clip_features = clip_model.encode_text(text_encoded).float()
|
| 82 |
|
| 83 |
+
motion_indices = transformer_model.sample(clip_features, False)
|
| 84 |
pred_pose = vqvae_model.forward_decoder(motion_indices)
|
| 85 |
pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
|
| 86 |
|
| 87 |
return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
|
| 88 |
|
| 89 |
def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
|
| 90 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
| 91 |
+
|
| 92 |
fig = plt.figure(figsize=(10, 10))
|
| 93 |
ax = fig.add_subplot(111, projection='3d')
|
| 94 |
data = np.array(joints).T
|
|
|
|
| 103 |
ani.save(save_path, writer=PillowWriter(fps=20))
|
| 104 |
plt.close(fig)
|
| 105 |
return save_path
|
| 106 |
+
|
| 107 |
def infer(text):
|
| 108 |
motion_data = generate_motion(text, vqvae_model, transformer_model)
|
| 109 |
gif_path = create_animation(motion_data)
|
|
|
|
| 111 |
|
| 112 |
with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
|
| 113 |
with gr.Column():
|
| 114 |
+
gr.Markdown("## 3D Human Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
|
| 115 |
text_input = gr.Textbox(label="Describe the action", placeholder="Enter text description for the action here...")
|
| 116 |
+
output_image = gr.Image(label="Generated Human Motion")
|
| 117 |
submit_button = gr.Button("Generate Motion")
|
| 118 |
|
| 119 |
submit_button.click(
|