| from typing import Optional |
| import os |
| import pathlib |
| import hydra |
| import copy |
| from hydra.core.hydra_config import HydraConfig |
| from omegaconf import OmegaConf |
| import dill |
| import torch |
| import threading |
|
|
|
|
class BaseWorkspace:
    """
    Base class for a workspace whose state can be saved to and restored
    from disk.

    Attributes that expose both ``state_dict`` and ``load_state_dict`` are
    checkpointed through their state dict. Other attributes are only
    checkpointed when their name is listed in ``include_keys``, in which
    case they are pickled with dill.
    """
    # Names of extra attributes to pickle into every checkpoint.
    include_keys = tuple()
    # Names of state-dict attributes to leave out of every checkpoint.
    exclude_keys = tuple()

    def __init__(self, cfg: "OmegaConf", output_dir: Optional[str]=None):
        self.cfg = cfg
        self._output_dir = output_dir
        # Most recent background checkpoint writer, if any.
        self._saving_thread = None

    @property
    def output_dir(self):
        """
        Directory that checkpoints and snapshots are written under.

        Uses the directory given at construction time; when that is None,
        falls back to the output directory of the active hydra run.
        """
        output_dir = self._output_dir
        if output_dir is None:
            output_dir = HydraConfig.get().runtime.output_dir
        return output_dir

    def run(self):
        """
        Create any resources that shouldn't be serialized as local variables.
        """
        pass

    @staticmethod
    def _write_to_file(obj, path):
        """Serialize obj to path with torch.save, closing the file afterwards."""
        with path.open('wb') as f:
            torch.save(obj, f, pickle_module=dill)

    def save_checkpoint(self, path=None, tag='latest',
            exclude_keys=None,
            include_keys=None,
            use_thread=True):
        """
        Write a checkpoint holding cfg, state dicts and pickled attributes.

        path: destination file; defaults to get_checkpoint_path(tag).
        tag: checkpoint name, only used when path is None.
        exclude_keys: state-dict attributes to skip; defaults to the
            class-level exclude_keys.
        include_keys: attributes to pickle with dill; defaults to the
            class-level include_keys plus '_output_dir'.
        use_thread: when True, the file is written by a background thread
            and this method returns without waiting for it.

        Returns the absolute checkpoint path as a string.
        """
        if path is None:
            path = self.get_checkpoint_path(tag=tag)
        else:
            path = pathlib.Path(path)
        if exclude_keys is None:
            exclude_keys = tuple(self.exclude_keys)
        if include_keys is None:
            include_keys = tuple(self.include_keys) + ('_output_dir',)

        path.parent.mkdir(parents=False, exist_ok=True)
        payload = {
            'cfg': self.cfg,
            'state_dicts': dict(),
            'pickles': dict()
        }

        for key, value in self.__dict__.items():
            if hasattr(value, 'state_dict') and hasattr(value, 'load_state_dict'):
                # Objects with a state dict are never pickled, even when
                # their name appears in include_keys.
                if key not in exclude_keys:
                    if use_thread:
                        # Hand the background writer its own copy of the
                        # state instead of the live state dict.
                        payload['state_dicts'][key] = _copy_to_cpu(value.state_dict())
                    else:
                        payload['state_dicts'][key] = value.state_dict()
            elif key in include_keys:
                payload['pickles'][key] = dill.dumps(value)

        # Wait for a writer that is still running so two saves never write
        # the same file at the same time.
        previous = getattr(self, '_saving_thread', None)
        if previous is not None and previous.is_alive():
            previous.join()

        if use_thread:
            self._saving_thread = threading.Thread(
                target=self._write_to_file, args=(payload, path))
            self._saving_thread.start()
        else:
            self._write_to_file(payload, path)
        return str(path.absolute())

    def get_checkpoint_path(self, tag='latest'):
        """Return <output_dir>/checkpoints/<tag>.ckpt as a pathlib.Path."""
        return pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt')

    def load_payload(self, payload, exclude_keys=None, include_keys=None, **kwargs):
        """
        Restore workspace state from a checkpoint payload.

        payload: dict with 'state_dicts' and 'pickles' entries, as built by
            save_checkpoint.
        exclude_keys: state-dict entries to skip; defaults to none.
        include_keys: pickled entries to restore; defaults to every entry
            in payload['pickles'].
        kwargs: forwarded to each load_state_dict call.

        Raises KeyError when a state-dict entry that is not excluded has no
        attribute of the same name on this workspace.
        """
        if exclude_keys is None:
            exclude_keys = tuple()
        if include_keys is None:
            include_keys = payload['pickles'].keys()

        for key, value in payload['state_dicts'].items():
            if key not in exclude_keys:
                self.__dict__[key].load_state_dict(value, **kwargs)
        for key in include_keys:
            # Requested keys that the payload does not contain are ignored.
            if key in payload['pickles']:
                self.__dict__[key] = dill.loads(payload['pickles'][key])

    def load_checkpoint(self, path=None, tag='latest',
            exclude_keys=None,
            include_keys=None,
            **kwargs):
        """
        Load a checkpoint file into this workspace and return its payload.

        path: checkpoint file; defaults to get_checkpoint_path(tag).
        kwargs: forwarded to torch.load (not to load_state_dict).
        """
        if path is None:
            path = self.get_checkpoint_path(tag=tag)
        else:
            path = pathlib.Path(path)
        # SECURITY: this unpickles with dill, which can execute arbitrary
        # code. Only load checkpoints from a trusted source.
        with path.open('rb') as f:
            payload = torch.load(f, pickle_module=dill, **kwargs)
        self.load_payload(payload,
            exclude_keys=exclude_keys,
            include_keys=include_keys)
        return payload

    @classmethod
    def create_from_checkpoint(cls, path,
            exclude_keys=None,
            include_keys=None,
            **kwargs):
        """
        Build a new workspace from the cfg stored in a checkpoint, then
        restore the checkpointed state into it.

        kwargs: forwarded to load_payload, and from there to each
            load_state_dict call.
        """
        # SECURITY: this unpickles with dill, which can execute arbitrary
        # code. Only load checkpoints from a trusted source.
        with open(path, 'rb') as f:
            payload = torch.load(f, pickle_module=dill)
        instance = cls(payload['cfg'])
        instance.load_payload(
            payload=payload,
            exclude_keys=exclude_keys,
            include_keys=include_keys,
            **kwargs)
        return instance

    def save_snapshot(self, tag='latest'):
        """
        Quick loading and saving for research, saves full state of the workspace.

        However, loading a snapshot assumes the code stays exactly the same.
        Use save_checkpoint for long-term storage.

        Returns the absolute snapshot path as a string.
        """
        path = pathlib.Path(self.output_dir).joinpath('snapshots', f'{tag}.pkl')
        path.parent.mkdir(parents=False, exist_ok=True)
        self._write_to_file(self, path)
        return str(path.absolute())

    @classmethod
    def create_from_snapshot(cls, path):
        """Return the workspace object stored in a snapshot file."""
        # SECURITY: this unpickles with dill, which can execute arbitrary
        # code. Only load snapshots from a trusted source.
        with open(path, 'rb') as f:
            return torch.load(f, pickle_module=dill)
|
|
|
|
| def _copy_to_cpu(x): |
| if isinstance(x, torch.Tensor): |
| return x.detach().to('cpu') |
| elif isinstance(x, dict): |
| result = dict() |
| for k, v in x.items(): |
| result[k] = _copy_to_cpu(v) |
| return result |
| elif isinstance(x, list): |
| return [_copy_to_cpu(k) for k in x] |
| else: |
| return copy.deepcopy(x) |
|
|