rsax commited on
Commit
bd8219d
·
verified ·
1 Parent(s): c213aac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -38,7 +38,14 @@ vqvae_model = vqvae.HumanVQVAE(args,
38
  args.width,
39
  args.depth,
40
  args.dilation_growth_rate).to(device)
41
- transformer_model = trans.Text2Motion_Transformer().to(device)
 
 
 
 
 
 
 
42
  vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
43
  transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
44
  vqvae_model.eval()
 
38
  args.width,
39
  args.depth,
40
  args.dilation_growth_rate).to(device)
41
+ transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
42
+ embed_dim=1024,
43
+ clip_dim=args.clip_dim,
44
+ block_size=args.block_size,
45
+ num_layers=9,
46
+ n_head=16,
47
+ drop_out_rate=args.drop_out_rate,
48
+ fc_rate=args.ff_rate).to(device)
49
  vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
50
  transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
51
  vqvae_model.eval()