| | from typing import Optional |
| | import os |
| | import pathlib |
| | import hydra |
| | import copy |
| | from hydra.core.hydra_config import HydraConfig |
| | from omegaconf import OmegaConf |
| | import dill |
| | import torch |
| | import threading |
| |
|
| |
|
class BaseWorkspace:
    """Base class for experiment workspaces with checkpoint/snapshot support.

    Attributes of the instance are treated as the workspace state:

    * attributes exposing both ``state_dict`` and ``load_state_dict`` are
      stored through their state dict;
    * attributes named in ``include_keys`` are stored as dill pickles;
    * everything else is ignored by ``save_checkpoint``.

    NOTE: checkpoints and snapshots are read back with ``dill`` via
    ``torch.load``. Unpickling can execute arbitrary code, so only load
    files that come from a trusted source.
    """

    # Names of extra attributes that save_checkpoint pickles with dill.
    include_keys = tuple()
    # Names of state_dict-bearing attributes that save_checkpoint skips.
    exclude_keys = tuple()

    # The annotation is a string so it is not evaluated when the class is
    # defined; it refers to the OmegaConf name imported at module level.
    def __init__(self, cfg: "OmegaConf", output_dir: Optional[str] = None):
        """Store the config and an optional explicit output directory.

        Args:
            cfg: Configuration object; saved verbatim under ``'cfg'`` in
                checkpoints.
            output_dir: Directory for checkpoints/snapshots. When ``None``,
                the ``output_dir`` property falls back to Hydra's runtime
                output directory.
        """
        self.cfg = cfg
        self._output_dir = output_dir
        # Handle of the most recent background save started by
        # save_checkpoint(use_thread=True), or None.
        self._saving_thread = None

    @property
    def output_dir(self):
        """Return the explicit output dir, else Hydra's runtime output dir."""
        if self._output_dir is not None:
            return self._output_dir
        return HydraConfig.get().runtime.output_dir

    def run(self):
        """Entry point meant to be overridden; the base version does nothing.

        Create any resource that should not be serialized as a local
        variable rather than as an attribute, because attributes are what
        ``save_checkpoint`` and ``save_snapshot`` collect.
        """
        pass

    def save_checkpoint(self, path=None, tag='latest',
                        exclude_keys=None,
                        include_keys=None,
                        use_thread=True):
        """Write state dicts and selected pickled attributes to a file.

        Args:
            path: Destination file. Defaults to
                ``<output_dir>/checkpoints/<tag>.ckpt``.
            tag: Checkpoint name used when ``path`` is ``None``.
            exclude_keys: Attribute names whose state dict is skipped.
                Defaults to the class-level ``exclude_keys``.
            include_keys: Attribute names stored as dill pickles. Defaults
                to the class-level ``include_keys`` plus ``'_output_dir'``.
            use_thread: When true, state dicts are copied to CPU first and
                the file is written on a background thread.

        Returns:
            The absolute checkpoint path as a string. With ``use_thread``
            the file may still be being written when this returns.
        """
        if path is None:
            path = pathlib.Path(self.output_dir).joinpath(
                'checkpoints', f'{tag}.ckpt')
        else:
            path = pathlib.Path(path)
        if exclude_keys is None:
            exclude_keys = tuple(self.exclude_keys)
        if include_keys is None:
            include_keys = tuple(self.include_keys) + ('_output_dir',)

        # parents=False: the directory above the checkpoint folder must
        # already exist, otherwise mkdir raises FileNotFoundError.
        path.parent.mkdir(parents=False, exist_ok=True)
        payload = {
            'cfg': self.cfg,
            'state_dicts': dict(),
            'pickles': dict()
        }

        for key, value in self.__dict__.items():
            if hasattr(value, 'state_dict') and hasattr(value, 'load_state_dict'):
                # Modules, optimizers, samplers, ... expose state_dict.
                if key not in exclude_keys:
                    if use_thread:
                        # Copy now so the background writer does not read
                        # state that the caller goes on modifying.
                        payload['state_dicts'][key] = _copy_to_cpu(value.state_dict())
                    else:
                        payload['state_dicts'][key] = value.state_dict()
            elif key in include_keys:
                payload['pickles'][key] = dill.dumps(value)

        def _write():
            # The context manager closes (and flushes) the file even when
            # torch.save raises.
            with path.open('wb') as f:
                torch.save(payload, f, pickle_module=dill)

        # Wait for a save that is still in flight so that two writers never
        # touch the same file (e.g. 'latest.ckpt') at the same time.
        previous = getattr(self, '_saving_thread', None)
        if previous is not None and previous.is_alive():
            previous.join()

        if use_thread:
            self._saving_thread = threading.Thread(target=_write)
            self._saving_thread.start()
        else:
            _write()
        return str(path.absolute())

    def get_checkpoint_path(self, tag='latest'):
        """Return ``<output_dir>/checkpoints/<tag>.ckpt`` as a ``Path``."""
        return pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt')

    def load_payload(self, payload, exclude_keys=None, include_keys=None, **kwargs):
        """Restore attributes from a payload produced by ``save_checkpoint``.

        Args:
            payload: Mapping with ``'state_dicts'`` and ``'pickles'`` entries.
            exclude_keys: State-dict keys to skip. Defaults to none.
            include_keys: Pickle keys to restore. Defaults to every key in
                ``payload['pickles']``.
            **kwargs: Forwarded to each ``load_state_dict`` call.

        Raises:
            KeyError: If a state-dict key has no matching attribute on
                this instance.
        """
        if exclude_keys is None:
            exclude_keys = tuple()
        if include_keys is None:
            include_keys = payload['pickles'].keys()

        for key, value in payload['state_dicts'].items():
            if key not in exclude_keys:
                self.__dict__[key].load_state_dict(value, **kwargs)
        for key in include_keys:
            if key in payload['pickles']:
                self.__dict__[key] = dill.loads(payload['pickles'][key])

    def load_checkpoint(self, path=None, tag='latest',
                        exclude_keys=None,
                        include_keys=None,
                        **kwargs):
        """Load a checkpoint file into this workspace.

        Args:
            path: Checkpoint file. Defaults to ``get_checkpoint_path(tag)``.
            tag: Checkpoint name used when ``path`` is ``None``.
            exclude_keys: Passed to ``load_payload``.
            include_keys: Passed to ``load_payload``.
            **kwargs: Forwarded to ``torch.load`` (not to
                ``load_state_dict``).

        Returns:
            The loaded payload dictionary.
        """
        if path is None:
            path = self.get_checkpoint_path(tag=tag)
        else:
            path = pathlib.Path(path)
        # NOTE: dill-based unpickling; only load trusted files.
        with path.open('rb') as f:
            payload = torch.load(f, pickle_module=dill, **kwargs)
        self.load_payload(payload,
                          exclude_keys=exclude_keys,
                          include_keys=include_keys)
        return payload

    @classmethod
    def create_from_checkpoint(cls, path,
                               exclude_keys=None,
                               include_keys=None,
                               **kwargs):
        """Build a workspace from the config stored in a checkpoint.

        The instance is constructed as ``cls(payload['cfg'])`` and then
        populated through ``load_payload``.

        Args:
            path: Checkpoint file to read.
            exclude_keys: Passed to ``load_payload``.
            include_keys: Passed to ``load_payload``.
            **kwargs: Forwarded to ``load_payload`` and from there to each
                ``load_state_dict`` call (not to ``torch.load``).

        Returns:
            The new, populated instance.
        """
        # NOTE: dill-based unpickling; only load trusted files.
        with open(path, 'rb') as f:
            payload = torch.load(f, pickle_module=dill)
        instance = cls(payload['cfg'])
        instance.load_payload(
            payload=payload,
            exclude_keys=exclude_keys,
            include_keys=include_keys,
            **kwargs)
        return instance

    def save_snapshot(self, tag='latest'):
        """
        Quick loading and saving for research; saves the full state of the
        workspace by pickling the whole object.

        However, loading a snapshot assumes the code stays exactly the same.
        Use save_checkpoint for long-term storage.

        Returns the absolute path of ``<output_dir>/snapshots/<tag>.pkl``.
        """
        path = pathlib.Path(self.output_dir).joinpath('snapshots', f'{tag}.pkl')
        path.parent.mkdir(parents=False, exist_ok=True)
        with path.open('wb') as f:
            torch.save(self, f, pickle_module=dill)
        return str(path.absolute())

    @classmethod
    def create_from_snapshot(cls, path):
        """Return the object stored in a snapshot file written by save_snapshot."""
        # NOTE: dill-based unpickling; only load trusted files.
        with open(path, 'rb') as f:
            return torch.load(f, pickle_module=dill)
| |
|
| |
|
def _copy_to_cpu(x):
    """Return an independent CPU copy of a (possibly nested) state structure.

    Tensors are detached and copied to CPU, dicts and lists are rebuilt
    recursively, and any other value is deep-copied. Dict subclasses come
    back as plain ``dict`` objects.

    Args:
        x: A tensor, dict, list, or arbitrary deep-copyable object.

    Returns:
        A structure that shares no tensor storage with ``x``.
    """
    if isinstance(x, torch.Tensor):
        # copy=True matters for tensors that already live on the CPU:
        # without it .to('cpu') returns a tensor aliasing the same storage,
        # so later in-place updates would leak into the "copy".
        return x.detach().to('cpu', copy=True)
    if isinstance(x, dict):
        return {k: _copy_to_cpu(v) for k, v in x.items()}
    if isinstance(x, list):
        return [_copy_to_cpu(item) for item in x]
    return copy.deepcopy(x)
| |
|