ndbao2002 commited on
Commit
283987a
·
verified ·
1 Parent(s): 731bc80

Update unet/DDPM_Unet_sample.py

Browse files
Files changed (1) hide show
  1. unet/DDPM_Unet_sample.py +1 -1
unet/DDPM_Unet_sample.py CHANGED
@@ -911,7 +911,7 @@ save_each = 1
911
  diffusion_model = diffusion_model.to(device)
912
 
913
  last_trained_path = '/content/DDPM_ResNet_Unet/unet/model/epoch_30.pth'
914
- diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path))['model'])
915
 
916
  sample_path = '/content/DDPM_ResNet_Unet/unet/sample'
917
 
 
911
  diffusion_model = diffusion_model.to(device)
912
 
913
  last_trained_path = '/content/DDPM_ResNet_Unet/unet/model/epoch_30.pth'
914
+ diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path), map_location=device)['model'])
915
 
916
  sample_path = '/content/DDPM_ResNet_Unet/unet/sample'
917