nachi1326 commited on
Commit
68d8709
·
verified ·
1 Parent(s): 61c04d4

Update cycleGAN/training/trainer.py

Browse files
Files changed (1) hide show
  1. cycleGAN/training/trainer.py +1 -1
cycleGAN/training/trainer.py CHANGED
@@ -78,7 +78,7 @@ class CGANTrainer():
78
  }, os.path.join(MODEL_SAVE_DIR, f"epoch_{epoch}.pth"))
79
 
80
  def load_checkpoints(self, file):
81
- checkpoint = torch.load(file)
82
 
83
  self.genA.load_state_dict(checkpoint['genA_state_dict'])
84
  self.genB.load_state_dict(checkpoint['genB_state_dict'])
 
78
  }, os.path.join(MODEL_SAVE_DIR, f"epoch_{epoch}.pth"))
79
 
80
  def load_checkpoints(self, file):
81
+ checkpoint = torch.load(file, map_location=device)
82
 
83
  self.genA.load_state_dict(checkpoint['genA_state_dict'])
84
  self.genB.load_state_dict(checkpoint['genB_state_dict'])