rsax commited on
Commit
b57fa67
·
verified ·
1 Parent(s): 6e02822

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -80,13 +80,15 @@ def generate_motion(text, vqvae_model, transformer_model):
80
  with torch.no_grad():
81
  clip_features = clip_model.encode_text(text_encoded).float()
82
 
83
- motion_indices = transformer_model.sample(clip_features, False) # Ensure the input to the transformer is float
84
  pred_pose = vqvae_model.forward_decoder(motion_indices)
85
  pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
86
 
87
  return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
88
 
89
  def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
 
 
90
  fig = plt.figure(figsize=(10, 10))
91
  ax = fig.add_subplot(111, projection='3d')
92
  data = np.array(joints).T
@@ -101,7 +103,7 @@ def create_animation(joints, title="3D Motion", save_path="static/animation.gif"
101
  ani.save(save_path, writer=PillowWriter(fps=20))
102
  plt.close(fig)
103
  return save_path
104
-
105
  def infer(text):
106
  motion_data = generate_motion(text, vqvae_model, transformer_model)
107
  gif_path = create_animation(motion_data)
@@ -109,9 +111,9 @@ def infer(text):
109
 
110
  with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
111
  with gr.Column():
112
- gr.Markdown("## 3D Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
113
  text_input = gr.Textbox(label="Describe the action", placeholder="Enter text description for the action here...")
114
- output_image = gr.Image(label="Generated Motion Animation")
115
  submit_button = gr.Button("Generate Motion")
116
 
117
  submit_button.click(
 
80
  with torch.no_grad():
81
  clip_features = clip_model.encode_text(text_encoded).float()
82
 
83
+ motion_indices = transformer_model.sample(clip_features, False)
84
  pred_pose = vqvae_model.forward_decoder(motion_indices)
85
  pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
86
 
87
  return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
88
 
89
  def create_animation(joints, title="3D Motion", save_path="static/animation.gif"):
90
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
91
+
92
  fig = plt.figure(figsize=(10, 10))
93
  ax = fig.add_subplot(111, projection='3d')
94
  data = np.array(joints).T
 
103
  ani.save(save_path, writer=PillowWriter(fps=20))
104
  plt.close(fig)
105
  return save_path
106
+
107
  def infer(text):
108
  motion_data = generate_motion(text, vqvae_model, transformer_model)
109
  gif_path = create_animation(motion_data)
 
111
 
112
  with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
113
  with gr.Column():
114
+ gr.Markdown("## 3D Human Motion Generation on " + ("GPU" if device == "cuda" else "CPU"))
115
  text_input = gr.Textbox(label="Describe the action", placeholder="Enter text description for the action here...")
116
+ output_image = gr.Image(label="Generated Human Motion")
117
  submit_button = gr.Button("Generate Motion")
118
 
119
  submit_button.click(