Update unet/DDPM_Unet_sample.py
Browse files- unet/DDPM_Unet_sample.py +1 -1
unet/DDPM_Unet_sample.py
CHANGED
|
@@ -911,7 +911,7 @@ save_each = 1
|
|
| 911 |
diffusion_model = diffusion_model.to(device)
|
| 912 |
|
| 913 |
last_trained_path = '/content/DDPM_ResNet_Unet/unet/model/epoch_30.pth'
|
| 914 |
-
diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path))['model'])
|
| 915 |
|
| 916 |
sample_path = '/content/DDPM_ResNet_Unet/unet/sample'
|
| 917 |
|
|
|
|
| 911 |
diffusion_model = diffusion_model.to(device)
|
| 912 |
|
| 913 |
last_trained_path = '/content/DDPM_ResNet_Unet/unet/model/epoch_30.pth'
|
| 914 |
+
diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path), map_location=device)['model'])
|
| 915 |
|
| 916 |
sample_path = '/content/DDPM_ResNet_Unet/unet/sample'
|
| 917 |
|