Spaces:
Build error
Build error
import torch
def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path="checkpoint.pth"):
    """Serialize a training snapshot to ``checkpoint_path`` via ``torch.save``.

    The object written is a dict with the keys ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``,
    in that order.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a
            ``torch.optim.Optimizer``).
        epoch: Epoch number stored with the snapshot and echoed to stdout.
        loss: Loss value, stored as given.
        checkpoint_path: Destination handed to ``torch.save``.
    """
    # Snapshot weights first, then optimizer state, matching key order.
    snapshot = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(snapshot, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(model, optimizer, checkpoint_path="checkpoint.pth", *,
                    map_location=None, return_epoch=False):
    """Restore model and optimizer state from a checkpoint file.

    The file is read with ``torch.load(..., weights_only=True)``, so only
    tensors and plain Python containers/primitives are unpickled. It must
    contain a dict with the keys ``'model_state_dict'``,
    ``'optimizer_state_dict'``, ``'epoch'`` and ``'loss'``. This is the
    layout written by ``save_checkpoint``.

    Args:
        model: Object whose ``load_state_dict`` receives the saved weights;
            it is updated in place.
        optimizer: Object whose ``load_state_dict`` receives the saved
            optimizer state; it is updated in place.
        checkpoint_path: File to read.
        map_location: Forwarded to ``torch.load``. ``None`` (the default)
            keeps the original behavior. Pass e.g. ``"cpu"`` to load a
            checkpoint saved on a GPU onto a machine without CUDA.
        return_epoch: When True, the stored epoch is appended to the
            returned tuple so the caller can actually resume from it.

    Returns:
        ``(model, optimizer, loss)`` by default, or
        ``(model, optimizer, loss, epoch)`` when ``return_epoch`` is True.

    Raises:
        KeyError: if the checkpoint dict lacks one of the expected keys.
    """
    checkpoint = torch.load(
        checkpoint_path, map_location=map_location, weights_only=True
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded, resuming from epoch {epoch}")
    # The epoch used to be read and then discarded. It is now returned only
    # on request, so existing 3-tuple callers keep working.
    if return_epoch:
        return model, optimizer, loss, epoch
    return model, optimizer, loss