DaniKaEp commited on
Commit
c25bef4
·
verified ·
1 Parent(s): 049c601

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -10,8 +10,13 @@ from VAE_model_tablets_class import VAE
10
 
11
  import gradio as gr
12
 
 
13
 
14
- model = VAE()
 
 
 
 
15
  model.load_state_dict(torch.load('epoch=22-step=213621.ckpt'))
16
  model.eval()
17
 
 
10
 
11
  import gradio as gr
12
 
13
+ num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
14
 
15
+ class_weights = torch.load("data/class_weights_period.pt")
16
+
17
+ checkpoint_path = 'epoch=22-step=213621.ckpt'
18
+ vae_model = VAE.load_from_checkpoint(checkpoint_path,image_channels=1,z_dim=12, lr =0.0001, use_classification_loss=True, num_classes=num_classes,
19
+ loss_type="weighted", class_weights=class_weights, device = device)
20
  model.load_state_dict(torch.load('epoch=22-step=213621.ckpt'))
21
  model.eval()
22