| | |
| | import os.path as osp |
| | import time |
| | from tempfile import TemporaryDirectory |
| |
|
| | import torch |
| | from torch.optim import Optimizer |
| |
|
| | import mmcv |
| | from mmcv.parallel import is_module_wrapper |
| | from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict |
| |
|
# apex is an optional dependency: importing this module must still succeed
# when it is missing, so only a notice is printed in that case.
try:
    import apex
except ImportError:
    # Catch ImportError only. A bare ``except`` would also swallow
    # KeyboardInterrupt/SystemExit and misreport unrelated failures as
    # "not installed".
    print('apex is not installed')
| |
|
| |
|
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint has up to 3 fields: ``meta``, ``state_dict`` and
    ``optimizer`` (the last one only when an optimizer is given). ``meta``
    always contains the mmcv version and the save time, plus ``CLASSES``
    when the model defines a non-None ``CLASSES`` attribute. No ``amp``
    state is written by this function.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename. A name starting with
            ``pavi://`` is uploaded through pavi modelcloud instead of
            being written to the local filesystem.
        optimizer (:obj:`Optimizer` or dict, optional): Optimizer to be
            saved, or a dict mapping names to optimizers.
        meta (dict, optional): Metadata to be saved in checkpoint. The
            dict passed in is not modified.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
        ImportError: If ``filename`` starts with ``pavi://`` and pavi
            cannot be imported.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    else:
        # Work on a shallow copy so the caller's dict is left untouched.
        meta = meta.copy()
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    # Unwrap module wrappers so the underlying module is inspected and saved.
    if is_module_wrapper(model):
        model = model.module

    if getattr(model, 'CLASSES', None) is not None:
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model)),
    }

    # A single optimizer is stored directly; a dict of optimizers is stored
    # as a dict of state dicts under the same names.
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict() for name, optim in optimizer.items()
        }

    pavi_prefix = 'pavi://'
    if filename.startswith(pavi_prefix):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError as err:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.'
            ) from err
        model_path = filename[len(pavi_prefix):]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        # Use a separate name for the modelcloud handle instead of
        # rebinding the ``model`` argument.
        try:
            cloud_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            cloud_model = root.create_training_model(model_dir)
        # Serialize to a temporary file first, then hand that file to pavi.
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            cloud_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
| |
|