DaniKaEp commited on
Commit
6d53984
·
verified ·
1 Parent(s): de1fe70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -10,6 +10,8 @@ from VAE_model_tablets_class import VAE
10
 
11
  import gradio as gr
12
 
 
 
13
  num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
14
 
15
  class_weights = torch.load("class_weights_period.pt")
 
10
 
11
  import gradio as gr
12
 
13
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+
15
  num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
16
 
17
  class_weights = torch.load("class_weights_period.pt")