rsax commited on
Commit
c213aac
·
verified ·
1 Parent(s): 1703337

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -20,17 +20,24 @@ warnings.filterwarnings('ignore')
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
- # Create an args object for initializing the HumanVQVAE
24
- class Args:
25
- def __init__(self):
26
- self.dataname = 't2m' # example property, adjust as needed
27
- self.quantizer = 'ema' # Set the type of quantizer used in VQVAE_251
28
- # Add other properties required by HumanVQVAE initialization
29
 
30
- args = Args()
 
 
 
 
 
31
 
32
-
33
- vqvae_model = vqvae.HumanVQVAE().to(device)
 
 
 
 
 
 
 
34
  transformer_model = trans.Text2Motion_Transformer().to(device)
35
  vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
36
  transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
 
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
+ args = option_trans.get_args_parser()
 
 
 
 
 
24
 
25
+ args.dataname = 't2m'
26
+ args.resume_pth = './output/VQVAE_imp_resnet_100k_hml3d/net_last.pth'
27
+ args.resume_trans = './output/net_best_fid.pth'
28
+ args.down_t = 2
29
+ args.depth = 3
30
+ args.block_size = 51
31
 
32
+ vqvae_model = vqvae.HumanVQVAE(args,
33
+ args.nb_code,
34
+ args.code_dim,
35
+ args.output_emb_width,
36
+ args.down_t,
37
+ args.stride_t,
38
+ args.width,
39
+ args.depth,
40
+ args.dilation_growth_rate).to(device)
41
  transformer_model = trans.Text2Motion_Transformer().to(device)
42
  vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
43
  transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))