Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,8 +10,13 @@ from VAE_model_tablets_class import VAE
|
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
|
|
|
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
model.load_state_dict(torch.load('epoch=22-step=213621.ckpt'))
|
| 16 |
model.eval()
|
| 17 |
|
|
|
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
|
| 13 |
+
num_classes = len(TabletPeriodDataset.PERIOD_INDICES)
|
| 14 |
|
| 15 |
+
class_weights = torch.load("data/class_weights_period.pt")
|
| 16 |
+
|
| 17 |
+
checkpoint_path = 'epoch=22-step=213621.ckpt'
|
| 18 |
+
vae_model = VAE.load_from_checkpoint(checkpoint_path,image_channels=1,z_dim=12, lr =0.0001, use_classification_loss=True, num_classes=num_classes,
|
| 19 |
+
loss_type="weighted", class_weights=class_weights, device = device)
|
| 20 |
model.load_state_dict(torch.load('epoch=22-step=213621.ckpt'))
|
| 21 |
model.eval()
|
| 22 |
|