| | from __future__ import division |
| | import os |
| | import torch |
| | import datetime |
| | import logging |
| |
|
# Module-level logger named after this module; used below to report which
# checkpoint file is being loaded.
logger = logging.getLogger(__name__)
| |
|
| |
|
class CheckpointSaver:
    """Save and restore training checkpoints inside a single directory.

    On construction the directory is created if needed and scanned
    (recursively) for ``*.pt`` files; the path that sorts last in natural
    (human) order is remembered as ``latest_checkpoint``.

    Attributes:
        save_dir: Absolute path of the checkpoint directory.
        save_steps: Stored as given; not used by any method of this class.
        overwrite: When true, every save goes to ``model_latest.pt``.
        latest_checkpoint: Path found by the last call to
            ``get_latest_checkpoint`` or ``None`` when no ``*.pt`` file
            existed at that time. It is not refreshed by ``save_checkpoint``.
    """

    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        """Create the checkpoint directory and look for existing checkpoints.

        Args:
            save_dir: Directory in which checkpoints are written and searched.
            save_steps: Stored on the instance as ``save_steps``.
            overwrite: If true, ``save_checkpoint`` always writes to
                ``model_latest.pt`` instead of a per-epoch / per-step file.
        """
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        # exist_ok avoids the check-then-create race when several processes
        # start with the same directory.
        os.makedirs(self.save_dir, exist_ok=True)
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Tell whether a checkpoint is available.

        Args:
            checkpoint_file: Optional explicit path. When omitted, the answer
                refers to the checkpoint discovered by the last directory scan.

        Returns:
            ``True`` if ``checkpoint_file`` is an existing file, or (when no
            path is given) if ``latest_checkpoint`` is set; ``False`` otherwise.
        """
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)

    def save_checkpoint(
        self,
        models,
        optimizers,
        epoch,
        batch_idx,
        batch_size,
        total_step_count,
        is_best=False,
        save_by_step=False,
        interval=5,
        with_optimizer=True
    ):
        """Write the current training state to disk.

        The target file is chosen as follows:

        * ``self.overwrite`` is true: ``model_latest.pt``.
        * else ``save_by_step`` is true: ``<total_step_count:08d>.pt``.
        * else, when ``epoch`` is a multiple of ``interval``:
          ``model_epoch_<epoch:02d>.pt``; otherwise no regular file is written.

        Independently of the above, ``is_best=True`` additionally writes
        ``model_best.pt``.

        Args:
            models: Mapping from name to an object exposing ``state_dict()``.
            optimizers: Mapping from name to an object exposing
                ``state_dict()``; only read when ``with_optimizer`` is true.
            epoch: Current epoch, stored in the checkpoint.
            batch_idx: Current batch index, stored in the checkpoint.
            batch_size: Batch size, stored in the checkpoint.
            total_step_count: Global step counter, stored in the checkpoint.
            is_best: Also save the state as ``model_best.pt``.
            save_by_step: Name the file after ``total_step_count``.
            interval: Epoch interval used for per-epoch files.
            with_optimizer: Include optimizer state dicts in the checkpoint.
        """
        timestamp = datetime.datetime.now()

        if self.overwrite:
            checkpoint_filename = os.path.join(self.save_dir, 'model_latest.pt')
        elif save_by_step:
            checkpoint_filename = os.path.join(
                self.save_dir, '{:08d}.pt'.format(total_step_count)
            )
        elif epoch % interval == 0:
            checkpoint_filename = os.path.join(
                self.save_dir, f'model_epoch_{epoch:02d}.pt'
            )
        else:
            checkpoint_filename = None
        if checkpoint_filename is not None:
            checkpoint_filename = os.path.abspath(checkpoint_filename)

        checkpoint = {}
        for name, model in models.items():
            model_dict = model.state_dict()
            # Entries whose key contains '.smpl.' are left out of the file;
            # presumably they belong to a sub-module that is rebuilt
            # elsewhere -- TODO confirm. Keys are snapshotted with list()
            # because the dict is modified while being traversed.
            for key in list(model_dict):
                if '.smpl.' in key:
                    del model_dict[key]
            checkpoint[name] = model_dict
        if with_optimizer:
            for name, optimizer in optimizers.items():
                checkpoint[name] = optimizer.state_dict()
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count
        print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)

        if checkpoint_filename is not None:
            torch.save(checkpoint, checkpoint_filename)
            print('Saving checkpoint file [' + checkpoint_filename + ']')
        if is_best:
            best_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt'))
            print('Saving checkpoint file [' + best_filename + ']')
            # Written exactly once (the file used to be serialized twice).
            torch.save(checkpoint, best_filename)
            print('Saved checkpoint file [' + best_filename + ']')

    def load_checkpoint(self, models, optimizers, checkpoint_file=None, map_location=None):
        """Restore model and optimizer state from a checkpoint file.

        For every model whose name is present in the checkpoint, only the
        entries whose keys also exist in the model's current ``state_dict``
        are taken from the file; all other parameters keep their current
        values. Optimizers whose name is present are restored in full.

        Args:
            models: Mapping from name to an object exposing ``state_dict()``
                and ``load_state_dict()``.
            optimizers: Mapping from name to an object exposing
                ``load_state_dict()``.
            checkpoint_file: File to load. Defaults to ``latest_checkpoint``.
            map_location: Forwarded to ``torch.load``; ``None`` keeps
                ``torch.load``'s default device placement.

        Returns:
            A dict with the stored ``epoch``, ``batch_idx``, ``batch_size``
            and ``total_step_count``.

        Raises:
            FileNotFoundError: If no ``checkpoint_file`` is given and no
                checkpoint was found in ``save_dir``.
        """
        if checkpoint_file is None:
            if self.latest_checkpoint is None:
                raise FileNotFoundError(
                    'No checkpoint found in [' + self.save_dir + ']'
                )
            checkpoint_file = self.latest_checkpoint
            logger.info('Loading latest checkpoint [%s]', checkpoint_file)

        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(checkpoint_file, map_location=map_location)

        for name, model in models.items():
            if name not in checkpoint:
                continue
            model_dict = model.state_dict()
            # Start from the model's own state and overlay the stored values
            # for keys the model knows, so keys missing from the file (or
            # unknown to the model) do not make load_state_dict fail.
            model_dict.update(
                {k: v for k, v in checkpoint[name].items() if k in model_dict}
            )
            model.load_state_dict(model_dict)

        for name, optimizer in optimizers.items():
            if name in checkpoint:
                optimizer.load_state_dict(checkpoint[name])

        return {
            'epoch': checkpoint['epoch'],
            'batch_idx': checkpoint['batch_idx'],
            'batch_size': checkpoint['batch_size'],
            'total_step_count': checkpoint['total_step_count']
        }

    def get_latest_checkpoint(self):
        """Scan ``save_dir`` and record the latest checkpoint path.

        All ``*.pt`` files below ``save_dir`` (sub-directories included) are
        collected and ordered with a natural sort of their absolute path, so
        that e.g. ``..._9.pt`` sorts before ``..._10.pt``. The last path is
        stored in ``self.latest_checkpoint`` (``None`` when nothing is found).
        "Latest" therefore means "last by name", not by modification time.
        """
        # Local import: keeps this method independent of the file's import block.
        import re

        # Float regex comes from https://stackoverflow.com/a/12643073/190597.
        # An optional sign is consumed but not captured.
        number = re.compile(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)')

        def natural_key(path):
            # With one capture group, re.split alternates text (even index)
            # and captured number (odd index). Only the numbers are converted,
            # so text such as 'nan' or 'inf' is never turned into a float and
            # every key has the same type at the same position.
            pieces = number.split(path)
            return [float(piece) if idx % 2 else piece
                    for idx, piece in enumerate(pieces)]

        checkpoint_list = [
            os.path.abspath(os.path.join(dirpath, filename))
            for dirpath, _dirnames, filenames in os.walk(self.save_dir)
            for filename in filenames
            if filename.endswith('.pt')
        ]
        checkpoint_list.sort(key=natural_key)
        self.latest_checkpoint = checkpoint_list[-1] if checkpoint_list else None
| |
|