Spaces:
Sleeping
Sleeping
File size: 1,087 Bytes
c1596ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | import os
import torch
def save_checkpoint(
    path,
    encoder,
    decoder,
    optimizer,
    epoch,
    train_loss,
    val_loss
):
    """Serialize encoder/decoder/optimizer state plus epoch metadata with ``torch.save``.

    The saved object is a dict with keys ``"epoch"``, ``"encoder_state_dict"``,
    ``"decoder_state_dict"``, ``"optimizer_state_dict"``, ``"train_loss"`` and
    ``"val_loss"``. The scalar values are stored exactly as passed in.

    Args:
        path: Destination for the checkpoint. A ``str``/``os.PathLike`` is
            written atomically (see below). Any other object is handed
            straight to ``torch.save`` (e.g. a writable binary file object).
        encoder: Object exposing ``state_dict()`` (presumably an ``nn.Module``).
        decoder: Object exposing ``state_dict()`` (presumably an ``nn.Module``).
        optimizer: Object exposing ``state_dict()`` (presumably a
            ``torch.optim.Optimizer``).
        epoch: Epoch number to record.
        train_loss: Training loss to record.
        val_loss: Validation loss to record.

    Raises:
        Whatever ``torch.save`` / ``os.replace`` raise (e.g. when the parent
        directory does not exist). The existing checkpoint at ``path`` is
        left untouched in that case.
    """
    checkpoint = {
        "epoch": epoch,
        "encoder_state_dict": encoder.state_dict(),
        "decoder_state_dict": decoder.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": train_loss,
        "val_loss": val_loss,
    }

    # File-like destinations cannot be renamed; keep the direct write.
    if not isinstance(path, (str, os.PathLike)):
        torch.save(checkpoint, path)
        return

    # Write to a sibling temp file first, then atomically swap it in, so an
    # interruption mid-write never leaves a truncated checkpoint at `path`
    # (which would destroy the previous good checkpoint).
    tmp_path = os.fspath(path) + ".tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of a partially written temp file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def load_checkpoint(
    best_path,
    encoder,
    decoder,
    optimizer,
    device
):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is read with ``torch.load(best_path, map_location=device)``
    and must be a dict containing ``"encoder_state_dict"``,
    ``"decoder_state_dict"``, ``"epoch"`` and ``"val_loss"``, plus
    ``"optimizer_state_dict"`` when an optimizer is given. This is the layout
    ``save_checkpoint`` in this module writes.

    Args:
        best_path: Path (or file-like) passed to ``torch.load``.
        encoder: Object whose ``load_state_dict`` receives the encoder state.
        decoder: Object whose ``load_state_dict`` receives the decoder state.
        optimizer: Object whose ``load_state_dict`` receives the optimizer
            state, or ``None`` to skip optimizer restoration (e.g. when only
            the weights are needed for evaluation).
        device: ``map_location`` for ``torch.load``.

    Returns:
        ``(checkpoint["epoch"] + 1, checkpoint["val_loss"])``: the epoch to
        resume from and the stored validation loss.

    Raises:
        KeyError: If a required key is missing from the checkpoint.
    """
    print(f"Loading checkpoint: {best_path}")
    checkpoint = torch.load(best_path, map_location=device)

    encoder.load_state_dict(checkpoint["encoder_state_dict"])
    decoder.load_state_dict(checkpoint["decoder_state_dict"])
    # optimizer=None lets callers reuse this helper for weight-only restores.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # Resume from the epoch after the one recorded in the checkpoint.
    resume_epoch = checkpoint["epoch"] + 1
    best_val_loss = checkpoint["val_loss"]
    print(
        f"Resume from Epoch {resume_epoch} | "
        f"Best Val Loss: {best_val_loss:.4f}"
    )
    return resume_epoch, best_val_loss