Spaces:
Build error
Build error
Commit
·
0880bdc
1
Parent(s):
28ba8c0
updated app to support CPU
Browse files
utils.py
CHANGED
|
@@ -12,12 +12,16 @@ def save_checkpoint(model, optimizer, epoch, loss, path):
|
|
| 12 |
print(f"Checkpoint saved at epoch {epoch}")
|
| 13 |
|
| 14 |
def load_checkpoint(model, optimizer, checkpoint_path):
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 17 |
if optimizer is not None:
|
| 18 |
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 19 |
start_epoch = checkpoint['epoch']
|
| 20 |
loss = checkpoint['loss']
|
|
|
|
| 21 |
return model, optimizer, start_epoch, loss
|
| 22 |
|
| 23 |
def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates):
|
|
|
|
| 12 |
print(f"Checkpoint saved at epoch {epoch}")
|
| 13 |
|
| 14 |
def load_checkpoint(model, optimizer, checkpoint_path):
|
| 15 |
+
# Use map_location to load the checkpoint on CPU if CUDA is not available
|
| 16 |
+
map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 17 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
| 18 |
+
|
| 19 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 20 |
if optimizer is not None:
|
| 21 |
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 22 |
start_epoch = checkpoint['epoch']
|
| 23 |
loss = checkpoint['loss']
|
| 24 |
+
|
| 25 |
return model, optimizer, start_epoch, loss
|
| 26 |
|
| 27 |
def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates):
|