rsax commited on
Commit
ec8b3ca
·
verified ·
1 Parent(s): 7a1fa76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -2
app.py CHANGED
@@ -20,8 +20,33 @@ warnings.filterwarnings('ignore')
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
- vqvae_model = vqvae.VQVAE().to(device)
24
- transformer_model = trans.Text2Motion_Transformer().to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
26
  transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
27
  vqvae_model.eval()
 
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
+ # Create an args object for initializing the HumanVQVAE
24
+ class Args:
25
+ def __init__(self):
26
+ self.dataname = 't2m' # example property, adjust as needed
27
+ self.quantizer = 'ema' # Set the type of quantizer used in VQVAE_251
28
+ # Add other properties required by HumanVQVAE initialization
29
+
30
+ args = Args()
31
+
32
+
33
+ vqvae_model = vqvae.HumanVQVAE(args,
34
+ args.nb_code,
35
+ args.code_dim,
36
+ args.output_emb_width,
37
+ args.down_t,
38
+ args.stride_t,
39
+ args.width,
40
+ args.depth,
41
+ args.dilation_growth_rate).to(device)
42
+ transformer_model = trans.Text2Motion_Transformer(num_vq=args.nb_code,
43
+ embed_dim=1024,
44
+ clip_dim=args.clip_dim,
45
+ block_size=args.block_size,
46
+ num_layers=9,
47
+ n_head=16,
48
+ drop_out_rate=args.drop_out_rate,
49
+ fc_rate=args.ff_rate).to(device)
50
  vqvae_model.load_state_dict(torch.load("output/VQVAE_imp_resnet_100k_hml3d/net_last.pth", map_location=device))
51
  transformer_model.load_state_dict(torch.load("output/net_best_fid.pth", map_location=device))
52
  vqvae_model.eval()