Spaces:
Sleeping
Sleeping
| from easydict import EasyDict | |
| import pytest | |
| from ding.utils import Scheduler | |
| from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config | |
| class TestSchedulerModule(): | |
| test_merged_scheduler_config = dict( | |
| schedule_flag=False, | |
| schedule_mode='reduce', | |
| factor=0.05, | |
| change_range=[-1, 1], | |
| threshold=1e-4, | |
| optimize_mode='min', | |
| patience=1, | |
| cooldown=0, | |
| ) | |
| test_merged_scheduler_config = EasyDict(test_merged_scheduler_config) | |
| test_policy_config = EasyDict(league_demo_ppo_config.policy) | |
| test_policy_config_param = test_policy_config.learn.entropy_weight | |
| def test_init_factor(self): | |
| self.test_merged_scheduler_config.factor = 'hello_test' | |
| with pytest.raises(AssertionError) as excinfo: | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert 'float/int' in str(excinfo.value) | |
| self.test_merged_scheduler_config.factor = 0 | |
| with pytest.raises(AssertionError) as excinfo: | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert 'greater than 0' in str(excinfo.value) | |
| # recover the correct value for later test function | |
| self.test_merged_scheduler_config.factor = 0.05 | |
| def test_init_change_range(self): | |
| self.test_merged_scheduler_config.change_range = 0 | |
| with pytest.raises(AssertionError) as excinfo: | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert 'list' in str(excinfo.value) | |
| self.test_merged_scheduler_config.change_range = [0, 'hello_test'] | |
| with pytest.raises(AssertionError) as excinfo: | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert 'float' in str(excinfo.value) | |
| self.test_merged_scheduler_config.change_range = [0, -1] | |
| with pytest.raises(AssertionError) as excinfo: | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert 'smaller' in str(excinfo.value) | |
| # recover the correct value for later test function | |
| self.test_merged_scheduler_config.change_range = [-1, 1] | |
| def test_init_patience(self): | |
| self.test_merged_scheduler_config.patience = "hello_test" | |
| with pytest.raises(AssertionError) as excinfo: | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert 'integer' in str(excinfo.value) | |
| self.test_merged_scheduler_config.patience = -1 | |
| with pytest.raises(AssertionError) as excinfo: | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert 'greater' in str(excinfo.value) | |
| # recover the correct value for later test function | |
| self.test_merged_scheduler_config.patience = 1 | |
| def test_is_better(self): | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert test_scheduler.is_better(-1) is True | |
| test_scheduler.last_metrics = 1 | |
| assert test_scheduler.is_better(0.5) is True | |
| def test_in_cooldown(self): | |
| self.test_merged_scheduler_config.cooldown_counter = 0 | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert test_scheduler.in_cooldown is False | |
| def test_step(self): | |
| self.test_merged_scheduler_config.cooldown = 1 | |
| test_scheduler = Scheduler(self.test_merged_scheduler_config) | |
| assert test_scheduler.cooldown_counter == 1 | |
| test_scheduler.last_metrics = 1.0 | |
| old_param = self.test_policy_config.learn.entropy_weight | |
| # good epoch with maximum cooldown lenth is 1 | |
| self.test_policy_config_param = test_scheduler.step(0.9, self.test_policy_config_param) | |
| assert self.test_policy_config_param == old_param | |
| assert test_scheduler.cooldown_counter == 0 | |
| assert test_scheduler.last_metrics == 0.9 | |
| assert test_scheduler.bad_epochs_num == 0 | |
| # first bad epoch in cooldown period | |
| self.test_policy_config_param = test_scheduler.step(0.899999, self.test_policy_config_param) | |
| assert self.test_policy_config_param == old_param | |
| assert test_scheduler.cooldown_counter == 0 | |
| assert test_scheduler.last_metrics == 0.899999 | |
| assert test_scheduler.bad_epochs_num == 1 | |
| # first bad epoch after cooldown | |
| self.test_policy_config_param = test_scheduler.step(0.899998, self.test_policy_config_param) | |
| assert self.test_policy_config_param == old_param - self.test_merged_scheduler_config.factor | |
| assert test_scheduler.cooldown_counter == 1 | |
| assert test_scheduler.last_metrics == 0.899998 | |
| assert test_scheduler.bad_epochs_num == 0 | |