rsax commited on
Commit
1703337
·
verified ·
1 Parent(s): ec8b3ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -17
app.py CHANGED
@@ -30,23 +30,8 @@ class Args:
30
  args = Args()
31
 
32
 
33
- vqvae_model = vqvae.HumanVQVAE(args,
34
- args.nb_code,
35
- args.code_dim,
36
- args.output_emb_width,
37
- args.down_t,
38
- args.stride_t,
39
- args.width,
40
- args.depth,
41
- args.dilation_growth_rate).to(device)
42
- transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
43
- embed_dim=1024,
44
- clip_dim=args.clip_dim,
45
- block_size=args.block_size,
46
- num_layers=9,
47
- n_head=16,
48
- drop_out_rate=args.drop_out_rate,
49
- fc_rate=args.ff_rate).to(device)
50
  vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
51
  transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
52
  vqvae_model.eval()
 
30
  args = Args()
31
 
32
 
33
+ vqvae_model = vqvae.HumanVQVAE().to(device)
34
+ transformer_model = trans.Text2Motion_Transformer().to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
36
  transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
37
  vqvae_model.eval()