Spaces:
Runtime error
Runtime error
Laishram Pongthangamba Meitei committed on
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,8 +17,8 @@ torch.cuda.empty_cache()
|
|
| 17 |
|
| 18 |
## Load autoencoder
|
| 19 |
|
| 20 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 21 |
-
|
| 22 |
autoencoderkl = AutoencoderKL(
|
| 23 |
spatial_dims=2,
|
| 24 |
in_channels=1,
|
|
@@ -33,7 +33,7 @@ autoencoderkl = AutoencoderKL(
|
|
| 33 |
root_dir = "models"
|
| 34 |
PATH_auto = f'{root_dir}/auto_encoder_model.pt'
|
| 35 |
|
| 36 |
-
autoencoderkl.load_state_dict(torch.load(PATH_auto))
|
| 37 |
autoencoderkl = autoencoderkl.to(device)
|
| 38 |
|
| 39 |
#### Load unet and embedings
|
|
@@ -60,8 +60,8 @@ embed = torch.nn.Embedding(num_embeddings=6, embedding_dim=embedding_dimension,
|
|
| 60 |
PATH_unet_condition = f'{root_dir}/unet_latent_space_model_condition.pt'
|
| 61 |
PATH_embed_condition = f'{root_dir}/embed_latent_space_model_condition.pt'
|
| 62 |
|
| 63 |
-
unet.load_state_dict(torch.load(PATH_unet_condition))
|
| 64 |
-
embed.load_state_dict(torch.load(PATH_embed_condition))
|
| 65 |
|
| 66 |
# unet.load_state_dict(checkpoint['model_state_dict'])
|
| 67 |
# embed.load_state_dict(checkpoint['embed_state_dict'])
|
|
|
|
| 17 |
|
| 18 |
## Load autoencoder
|
| 19 |
|
| 20 |
+
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 21 |
+
device = torch.device('cpu')
|
| 22 |
autoencoderkl = AutoencoderKL(
|
| 23 |
spatial_dims=2,
|
| 24 |
in_channels=1,
|
|
|
|
| 33 |
root_dir = "models"
|
| 34 |
PATH_auto = f'{root_dir}/auto_encoder_model.pt'
|
| 35 |
|
| 36 |
+
autoencoderkl.load_state_dict(torch.load(PATH_auto,map_location=device))
|
| 37 |
autoencoderkl = autoencoderkl.to(device)
|
| 38 |
|
| 39 |
#### Load unet and embedings
|
|
|
|
| 60 |
PATH_unet_condition = f'{root_dir}/unet_latent_space_model_condition.pt'
|
| 61 |
PATH_embed_condition = f'{root_dir}/embed_latent_space_model_condition.pt'
|
| 62 |
|
| 63 |
+
unet.load_state_dict(torch.load(PATH_unet_condition,map_location=device))
|
| 64 |
+
embed.load_state_dict(torch.load(PATH_embed_condition,map_location=device))
|
| 65 |
|
| 66 |
# unet.load_state_dict(checkpoint['model_state_dict'])
|
| 67 |
# embed.load_state_dict(checkpoint['embed_state_dict'])
|