Spaces:
Build error
Build error
added utils.py
Browse files
utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def save_model(model, optimizer, epoch, loss, directory, model_name='model', **kwargs):
    """
    Save a PyTorch training checkpoint and return the path it was written to.

    The checkpoint is a dict holding 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss', plus one entry per extra keyword
    argument.

    Args:
        model: Trained model; its ``state_dict()`` is stored.
        optimizer: Optimizer used for training; its ``state_dict()`` is stored.
        epoch: The last epoch the model was trained on.
        loss: The last loss recorded during training (formatted with ``.4f``
            in the filename).
        directory: The directory where to save the model. It is created,
            including parents, if it does not exist.
        model_name: Base name for the model file, defaults to 'model'.
        kwargs: Additional metrics. Each one is stored in the checkpoint and
            appended to the filename as ``key=value`` with four decimals, so
            values must support the ``.4f`` format spec.

    Returns:
        pathlib.Path: Location of the written ``.pth`` file.

    Example:
        >>> save_model(model, optimizer, epoch, loss, './model_dir', f1_score=val_f1score)
    """
    # Create the directory if it does not exist
    target_dir = Path(directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Build the filename from its parts so that an empty metrics set does not
    # leave a dangling '_' in front of the extension.
    name_parts = [str(model_name), f'epoch={epoch}', f'loss={loss:.4f}']
    name_parts.extend(f'{key}={value:.4f}' for key, value in kwargs.items())
    checkpoint_path = target_dir / ('_'.join(name_parts) + '.pth')

    # Save the model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        **kwargs
    }, checkpoint_path)

    return checkpoint_path
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_device() -> torch.device:
    """
    Retrieves the appropriate Torch device for running computations.

    The preference order is CUDA (NVIDIA GPU), then MPS (Apple GPU), then CPU.

    Returns:
        torch.device: The Torch device to be used for computations.

    Examples:
        >>> device = get_device()
        >>> print(device)  # on a machine with an NVIDIA GPU
        cuda

    """
    if torch.cuda.is_available():
        return torch.device("cuda")  # NVIDIA GPU

    # Look the MPS backend up defensively: PyTorch builds that do not ship
    # ``torch.backends.mps`` would otherwise raise AttributeError here.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")  # Apple GPU

    # Defaults to CPU if NVIDIA GPU/Apple GPU aren't available
    return torch.device("cpu")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def load_checkpoint(model, optimizer, filename, map_location=None):
    """
    Load a PyTorch model checkpoint.

    The file is expected to contain a dict with at least the keys 'epoch',
    'model_state_dict', 'optimizer_state_dict' and 'loss'; any other keys are
    returned as additional metrics.

    Args:
        model: Model to load the weights into.
        optimizer: Optimizer to load the state into. Pass ``None`` to skip
            restoring optimizer state (e.g. when loading for inference only).
        filename: The path of the checkpoint file.
        map_location: Forwarded to ``torch.load``. Use e.g. ``'cpu'`` to load a
            checkpoint saved on a GPU onto a machine without one. Defaults to
            ``None``, which leaves ``torch.load``'s own behaviour unchanged.

    Returns:
        The epoch at which training was stopped, the last loss recorded, and any additional metrics.

    Raises:
        KeyError: If a required key is missing from the checkpoint.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(filename, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    # Extract additional metrics: everything that is not a core checkpoint key
    core_keys = {'epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'}
    metrics = {key: value for key, value in checkpoint.items() if key not in core_keys}

    return epoch, loss, metrics
|
| 89 |
+
|
| 90 |
+
# To use the function, you would do something like this:
|
| 91 |
+
# epoch, loss, metrics = load_checkpoint(model, optimizer, 'model_checkpoint.pth')
|