rsax committed on
Commit ad7134f · verified · 1 Parent(s): f6c4439

Update app.py

Files changed (1)
  1. app.py (+5, -10)
app.py CHANGED
@@ -49,12 +49,12 @@ transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
                                                     fc_rate=args.ff_rate).to(device)
 
 vqvae_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device)
-transformed_state_dict = {}
+transformed_vqvae_state_dict = {}
 for k, v in vqvae_checkpoint['net'].items():
-    new_key = k.replace("vqvae.", "")
-    transformed_state_dict[new_key] = v
+    new_key = k.replace("vqvae.", "")
+    transformed_vqvae_state_dict[new_key] = v
 
-vqvae_model.load_state_dict(transformed_state_dict, strict=False)
+vqvae_model.load_state_dict(transformed_vqvae_state_dict, strict=False)
 
 transformer_checkpoint = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_best_fid.pth", map_location=device)
 transformed_transformer_state_dict = {}
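The renamed dict above follows the usual PyTorch pattern for loading a checkpoint that was saved from a wrapper module: strip the wrapper prefix (here "vqvae.") from every key before calling load_state_dict. A minimal sketch of that pattern, with a hypothetical strip_prefix helper and placeholder model names (not the app's actual objects):

import torch

def strip_prefix(state_dict, prefix):
    # Drop a wrapper-module prefix (e.g. "vqvae.") from every checkpoint key
    # so the weights can be loaded into the bare sub-module.
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

# Hypothetical usage mirroring the lines above; the path and the 'net' key are
# taken from the snippet, vqvae_model stands in for the app's model object.
# ckpt = torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location="cpu")
# vqvae_model.load_state_dict(strip_prefix(ckpt["net"], "vqvae."), strict=False)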
@@ -74,7 +74,7 @@ def generate_motion(text, vqvae_model, transformer_model):
     clip_text = [text]
     text_encoded = clip.tokenize(clip_text, truncate=True).to(device)
     with torch.no_grad():
-        motion_indices = transformer_model.sample(text_encoded, False)
+        motion_indices = transformer_model.sample(text_encoded.float(), False)
     pred_pose = vqvae_model.forward_decoder(motion_indices)
     pred_xyz = recover_from_ric((pred_pose * std + mean).float(), 22)
     return pred_xyz.cpu().numpy().reshape(-1, 22, 3)
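The only change in this hunk is the .float() cast on the CLIP token tensor before sampling. A small stand-alone illustration of just that step, assuming the OpenAI clip package that app.py imports (it exercises clip.tokenize only, not the app's models; the prompt string is one of the example prompts):

import torch
import clip  # OpenAI CLIP package, as used by app.py

tokens = clip.tokenize(["A person doing a kick"], truncate=True)
assert not torch.is_floating_point(tokens)  # token ids come back as an integer tensor
tokens = tokens.float()                     # the cast this commit adds before sample()
assert tokens.dtype == torch.float32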
@@ -94,11 +94,6 @@ def create_animation(joints, title="3D Motion", save_path="static/animation.gif"
     ani.save(save_path, writer=PillowWriter(fps=20))
     plt.close(fig)
     return save_path
-
-examples = [
-    "A person doing a kick",
-    "A person is dancing ballet",
-]
 
 def infer(text):
     motion_data = generate_motion(text, vqvae_model, transformer_model)
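For reference, create_animation (the function this hunk ends) writes its frames out with matplotlib's PillowWriter at 20 fps. A self-contained sketch of that save step, with a dummy 2D curve standing in for the app's 22-joint 3D motion:

import matplotlib
matplotlib.use("Agg")  # headless backend; we only save, never show
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter

fig, ax = plt.subplots()
line, = ax.plot([], [])
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1, 1)
x = np.linspace(0, 2 * np.pi, 100)

def update(frame):
    # Redraw one frame; the real app updates a 3D view of 22 joints instead.
    line.set_data(x, np.sin(x + frame / 10))
    return line,

ani = FuncAnimation(fig, update, frames=40)
ani.save("animation.gif", writer=PillowWriter(fps=20))  # same writer/fps as create_animation
plt.close(fig)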
 