Spaces:
Build error
Build error
Update cycleGAN/training/trainer.py
Browse files
cycleGAN/training/trainer.py
CHANGED
|
@@ -78,7 +78,7 @@ class CGANTrainer():
|
|
| 78 |
}, os.path.join(MODEL_SAVE_DIR, f"epoch_{epoch}.pth"))
|
| 79 |
|
| 80 |
def load_checkpoints(self, file):
|
| 81 |
-
checkpoint = torch.load(file)
|
| 82 |
|
| 83 |
self.genA.load_state_dict(checkpoint['genA_state_dict'])
|
| 84 |
self.genB.load_state_dict(checkpoint['genB_state_dict'])
|
|
|
|
| 78 |
}, os.path.join(MODEL_SAVE_DIR, f"epoch_{epoch}.pth"))
|
| 79 |
|
| 80 |
def load_checkpoints(self, file):
|
| 81 |
+
checkpoint = torch.load(file, map_location=device)
|
| 82 |
|
| 83 |
self.genA.load_state_dict(checkpoint['genA_state_dict'])
|
| 84 |
self.genB.load_state_dict(checkpoint['genB_state_dict'])
|