Spaces:
Running
Running
| import os | |
| import time | |
| import pytest | |
| import torch | |
| from easydict import EasyDict | |
| from typing import Any | |
| from functools import partial | |
| from ding.worker import BaseLearner | |
| from ding.worker.learner import LearnerHook, add_learner_hook, create_learner | |
class FakeLearner(BaseLearner):
    """``BaseLearner`` subclass that serves synthetic data for the tests in this file."""

    @staticmethod
    def random_data():
        """Build one synthetic sample.

        Declared as a static method because it takes no ``self``; without the
        decorator, ``self.random_data`` would be a bound method and calling it
        would raise ``TypeError`` (one positional argument given, none accepted).

        Returns:
            dict: ``'obs'`` is a fresh ``torch.randn(2)`` tensor;
            ``'replay_buffer_idx'`` and ``'replay_unique_id'`` are both ``0``.
        """
        return {
            'obs': torch.randn(2),
            'replay_buffer_idx': 0,
            'replay_unique_id': 0,
        }

    def get_data(self, batch_size):
        """Return ``batch_size`` references to :meth:`random_data`.

        The entries are the sample *factory* itself, not generated samples.
        NOTE(review): presumably the learner's dataloader calls each entry to
        materialise the data lazily -- confirm against ``BaseLearner``.

        Args:
            batch_size: Number of entries in the returned list.

        Returns:
            list: ``batch_size`` copies of the ``random_data`` callable.
        """
        return [self.random_data for _ in range(batch_size)]
class FakePolicy:
    """Stub policy exposing fixed, canned answers for the learner test."""

    # (attribute name, canned value) pairs served by ``get_attribute``.
    _ATTRIBUTES = (
        ('cuda', False),
        ('device', 'cpu'),
        ('batch_size', 2),
        ('on_policy', False),
    )

    def __init__(self):
        self._model = torch.nn.Identity()

    def forward(self, x):
        """Ignore ``x`` and return a fixed-layout result dict with a random scalar loss."""
        loss = torch.randn(1).squeeze()
        scalar_group = {'a': 5., 'b': 4.}
        result = {
            'total_loss': loss,
            'cur_lr': 0.1,
            'priority': [1., 2., 3.],
            '[histogram]h_example': [1.2, 2.3, 3.4],
            '[scalars]s_example': scalar_group,
        }
        return result

    def data_preprocess(self, x):
        """Hand the input back untouched."""
        return x

    def state_dict(self):
        """Return a dict whose ``'model'`` entry is the wrapped module object itself."""
        return dict(model=self._model)

    def load_state_dict(self, state_dict):
        """Accept and discard ``state_dict``; nothing is restored."""
        return None

    def info(self):
        """Return the fixed description string of this policy."""
        description = 'FakePolicy'
        return description

    def monitor_vars(self):
        """Return the names of the monitored variables."""
        return list(('total_loss', 'cur_lr'))

    def get_attribute(self, name):
        """Look up one of the canned attributes.

        Raises:
            KeyError: If ``name`` is not one of the known attribute names.
        """
        for known_name, value in self._ATTRIBUTES:
            if name == known_name:
                return value
        raise KeyError

    def reset(self):
        """Do nothing; the stub holds no state to reset."""
        return None
class TestBaseLearner:
    """Smoke test that runs a ``FakeLearner`` end to end and checks its hooks and checkpoints."""

    @staticmethod
    def _remove_matching(directory, pattern):
        """Synchronously delete every file or directory under ``directory`` matching ``pattern``.

        Replaces the former ``os.popen('rm -rf ...')`` calls, which ran
        asynchronously, leaked their pipe handles and only worked on POSIX
        shells. Removal is best-effort, like ``rm -rf``: missing or
        undeletable paths are skipped.

        Args:
            directory: Directory to search; taken literally (glob characters are escaped).
            pattern: Glob pattern, relative to ``directory``.
        """
        # Imported locally so the block does not depend on edits to the file header.
        import glob
        import shutil

        for target in glob.glob(os.path.join(glob.escape(directory), pattern)):
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target, ignore_errors=True)
            else:
                try:
                    os.remove(target)
                except OSError:
                    # Best-effort cleanup must never mask the real test outcome.
                    pass

    def _get_cfg(self, path):
        """Build the learner config used by :meth:`test_naive`.

        Args:
            path: Checkpoint path assigned to ``cfg.hook.load_ckpt_before_run``.

        Returns:
            The ``BaseLearner.default_config()`` object with 10 train iterations,
            a log-show setting of 5 and a fully specified save-checkpoint hook
            with ``freq`` 5.
        """
        cfg = BaseLearner.default_config()
        cfg.import_names = []
        cfg.learner_type = 'fake'
        cfg.train_iterations = 10
        cfg.hook.load_ckpt_before_run = path
        cfg.hook.log_show_after_iter = 5
        # Another way to build hook: Complete config
        cfg.hook.save_ckpt_after_iter = dict(
            name='save_ckpt_after_iter', type='save_ckpt', priority=40, position='after_iter', ext_args={'freq': 5}
        )
        return cfg

    def test_naive(self):
        """Run the learner for 10 iterations from an iteration-5 checkpoint and verify the results."""
        # Removal is synchronous now, so no sleep is needed to wait for a shell.
        self._remove_matching(os.curdir, 'iteration_5.pth.tar*')
        with pytest.raises(KeyError):
            create_learner(EasyDict({'type': 'placeholder', 'import_names': []}))
        ckpt_dir = os.path.dirname(__file__)
        path = os.path.join(ckpt_dir, './iteration_5.pth.tar')
        learner = None
        dir_name = None
        try:
            torch.save({'model': {}, 'last_iter': 5}, path)
            time.sleep(0.5)
            cfg = self._get_cfg(path)
            learner = FakeLearner(cfg, exp_name='exp_test')
            # Resolved early so the checkpoint directory is cleaned even if a later step fails.
            dir_name = '{}/ckpt'.format(learner.exp_name)
            learner.policy = FakePolicy()
            learner.setup_dataloader()
            learner.start()
            time.sleep(2)
            # 5 iterations restored from the checkpoint plus 10 trained here.
            assert learner.last_iter.val == 10 + 5

            # test hook
            for n in [5, 10, 15]:
                assert os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
            for n in [0, 4, 7, 12]:
                assert not os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
            learner.debug('iter [5, 10, 15] exists; iter [0, 4, 7, 12] does not exist.')

            learner.save_checkpoint('best')

            info = learner.learn_info
            for info_name in ['learner_step', 'priority_info', 'learner_done']:
                assert info_name in info

            class FakeHook(LearnerHook):

                def __call__(self, engine: Any) -> Any:
                    pass

            original_hook_num = len(learner._hooks['after_run'])
            add_learner_hook(learner._hooks, FakeHook(name='fake_hook', priority=30, position='after_run'))
            assert len(learner._hooks['after_run']) == original_hook_num + 1
        finally:
            try:
                # Close before deleting so the learner is not left holding removed paths.
                if learner is not None:
                    learner.close()
            finally:
                self._remove_matching(os.curdir, 'iteration_5.pth.tar*')
                # The seed checkpoint lives next to this file, which may differ from the cwd.
                self._remove_matching(ckpt_dir, 'iteration_5.pth.tar*')
                if dir_name is not None:
                    self._remove_matching(os.curdir, dir_name)
                self._remove_matching(os.curdir, 'learner')
                self._remove_matching(os.curdir, 'log')