File size: 6,920 Bytes
747451d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
# * All rights reserved.
# *
# * Copyright 2020 Ross Wightman
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
""" Checkpoint Saver
Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
Hacked together by / Copyright 2020 Ross Wightman
"""
import glob
import operator
import os
import torch
from timm.utils.model import unwrap_model, get_state_dict
from common.utils import LOGGER
class CheckpointSaver:
    """Track the top-n training checkpoints and maintain rolling recovery checkpoints.

    The optimizer state dict is deliberately not saved, since it forces a specific repo
    structure on whoever loads the checkpoint. The AMP scaler state is not saved either.

    Args:
        model: Model whose state dict is written to every checkpoint.
        optimizer: Stored on the instance only; its state is not serialized.
        args: Optional run arguments. When given, they are stored in the checkpoint under
            ``'args'`` and ``args.model_name`` is used as the ``'arch'`` entry.
        model_ema: Optional EMA copy of the model. Its weights replace the regular weights
            in a checkpoint when its metric is better than the regular metric.
        amp_scaler: Stored on the instance only; its state is not serialized.
        checkpoint_prefix: Filename prefix for the per-epoch checkpoints.
        recovery_prefix: Filename prefix for recovery checkpoints.
        checkpoint_dir: Directory receiving ``last``, ``model_best`` and per-epoch files.
        recovery_dir: Directory receiving recovery files.
        decreasing: True if a lower metric is better.
        max_history: Maximum number of per-epoch checkpoints kept on disk (>= 1).
        unwrap_fn: Callable handed to ``get_state_dict`` to unwrap the model.
    """

    def __init__(
            self,
            model,
            optimizer,
            args=None,
            model_ema=None,
            amp_scaler=None,
            checkpoint_prefix='checkpoint',
            recovery_prefix='recovery',
            checkpoint_dir='',
            recovery_dir='',
            decreasing=False,
            max_history=10,
            unwrap_fn=unwrap_model):

        # objects to save state_dicts of
        self.model = model
        self.optimizer = optimizer
        self.args = args
        self.model_ema = model_ema
        self.amp_scaler = amp_scaler

        # state
        self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
        self.best_epoch = None
        self.best_metric = None
        self.curr_recovery_file = ''
        self.last_recovery_file = ''

        # config
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.save_prefix = checkpoint_prefix
        self.recovery_prefix = recovery_prefix
        self.extension = '.pth.tar'
        self.decreasing = decreasing  # a lower metric is better if True
        self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
        self.max_history = max_history
        self.unwrap_fn = unwrap_fn
        assert self.max_history >= 1

    def save_checkpoint(self, epoch, metric=None, metric_ema=None):
        """Save the checkpoint for ``epoch`` and update the top-n / best bookkeeping.

        The state is always written to ``last<ext>``. It is additionally hard-linked as a
        per-epoch checkpoint when there is room in the history, when ``metric`` is None, or
        when ``metric`` beats the worst tracked checkpoint; and as ``model_best<ext>`` when
        ``metric`` beats the best metric seen so far.

        Args:
            epoch: Epoch index (>= 0), used in the per-epoch filename.
            metric: Evaluation metric of the regular weights, or None.
            metric_ema: Evaluation metric of the EMA weights, or None.

        Returns:
            ``(best_metric, best_epoch)``, or ``(None, None)`` if no metric was recorded yet.
        """
        assert epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        self._save(tmp_save_path, epoch, metric, metric_ema)
        # os.replace overwrites an existing destination on every platform (Windows included),
        # so there is no window in which 'last' is missing.
        os.replace(tmp_save_path, last_save_path)

        # The list is kept best-first, so the last entry holds the worst tracked metric.
        worst_metric = self.checkpoint_files[-1][1] if self.checkpoint_files else None
        keep = (
            len(self.checkpoint_files) < self.max_history
            or metric is None
            or worst_metric is None  # an unscored checkpoint counts as worse than any scored one
            or self.cmp(metric, worst_metric)
        )
        if keep:
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)
            filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
            save_path = os.path.join(self.checkpoint_dir, filename)
            os.link(last_save_path, save_path)
            self.checkpoint_files.append((save_path, metric))
            self.checkpoint_files = self._rank_checkpoints(self.checkpoint_files)

            checkpoints_str = "Current checkpoints:\n" + ''.join(
                ' {}\n'.format(c) for c in self.checkpoint_files)
            LOGGER.info(checkpoints_str)

            if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
                self.best_epoch = epoch
                self.best_metric = metric
                best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
                if os.path.exists(best_save_path):
                    os.unlink(best_save_path)  # os.link refuses to overwrite an existing file
                os.link(last_save_path, best_save_path)

        return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

    def _rank_checkpoints(self, entries):
        """Return ``entries`` ordered best-first, with metric-less entries at the end.

        ``None`` metrics cannot be compared with numbers (or with each other), so they are
        excluded from the sort and appended afterwards in their existing order.
        """
        scored = [entry for entry in entries if entry[1] is not None]
        unscored = [entry for entry in entries if entry[1] is None]
        # sort in descending order if a lower metric is not better
        scored.sort(key=operator.itemgetter(1), reverse=not self.decreasing)
        return scored + unscored

    def _save(self, save_path, epoch, metric=None, metric_ema=None):
        """Serialize the current training state to ``save_path`` with ``torch.save``.

        Args:
            save_path: Destination file.
            epoch: Epoch index stored in the checkpoint.
            metric: Metric of the regular weights; stored under ``'metric'`` when not None.
            metric_ema: Metric of the EMA weights; only used to choose which weights to save.
        """
        save_state = {
            'epoch': epoch,
            'arch': type(self.model).__name__.lower(),
            'state_dict': get_state_dict(self.model, self.unwrap_fn),
            'version': 2,  # version < 2 increments epoch before save
        }
        # NOTE: optimizer and AMP scaler state are intentionally left out of the checkpoint.
        if self.args is not None:
            save_state['arch'] = self.args.model_name
            save_state['args'] = self.args
        # Save EMA weights instead of regular weights when they are better. "Better" follows
        # self.cmp so that lower-is-better metrics are handled, and the comparison is skipped
        # when either metric is missing (e.g. recovery checkpoints pass no metrics).
        ema_is_better = (
            self.model_ema is not None
            and metric is not None
            and metric_ema is not None
            and self.cmp(metric_ema, metric)
        )
        if ema_is_better:
            save_state['state_dict'] = get_state_dict(self.model_ema, self.unwrap_fn)
        if metric is not None:
            save_state['metric'] = metric
        torch.save(save_state, save_path)

    def _cleanup_checkpoints(self, trim=0):
        """Delete the worst tracked checkpoints so that ``max_history - trim`` remain.

        Args:
            trim: Number of extra slots to free below ``max_history``.
        """
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            # Best effort: a file that cannot be removed is logged and then forgotten.
            try:
                LOGGER.debug("Cleaning checkpoint: %s", d)
                os.remove(d[0])
            except Exception as e:
                LOGGER.error("Exception '%s' while deleting checkpoint", e)
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def save_recovery(self, epoch, batch_idx=0):
        """Write a recovery checkpoint and drop the one from two calls ago.

        Args:
            epoch: Epoch index (>= 0), used in the filename.
            batch_idx: Batch index within the epoch, used in the filename.
        """
        assert epoch >= 0
        filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
        save_path = os.path.join(self.recovery_dir, filename)
        self._save(save_path, epoch)
        if os.path.exists(self.last_recovery_file):
            try:
                LOGGER.debug("Cleaning recovery: %s", self.last_recovery_file)
                os.remove(self.last_recovery_file)
            except Exception as e:
                LOGGER.error("Exception '%s' while removing %s", e, self.last_recovery_file)
        # Rotate: the previous "current" file becomes the next one to be deleted.
        self.last_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        """Return the first recovery file in sorted path order, or '' if there is none."""
        recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
        files = sorted(glob.glob(recovery_path + '*' + self.extension))
        # NOTE(review): the order is lexicographic on the path, not by epoch or mtime, so this
        # is not guaranteed to be the most recent file - confirm this is what callers expect.
        return files[0] if files else ''