Spaces:
Running
Running
| from collections import namedtuple | |
| import numpy as np | |
| import pytest | |
| import torch | |
| import treetensor.torch as ttorch | |
| from ding.utils.default_helper import lists_to_dicts, dicts_to_lists, squeeze, default_get, override, error_wrapper, \ | |
| list_split, LimitedSpaceContainer, set_pkg_seed, deep_merge_dicts, deep_update, flatten_dict, RunningMeanStd, \ | |
| one_time_warning, split_data_generator, get_shape0 | |
class TestDefaultHelper:
    """Unit tests for the general-purpose helpers imported from ``ding.utils.default_helper``."""

    def test_get_shape0(self):
        """``get_shape0`` reports the leading dimension of nested tensor containers."""
        a = {
            'a': {
                'b': torch.randn(4, 3)
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        # ``b`` and ``c`` capture the plain nested dict (before ``a`` is rebound
        # to a treetensor below), so list and tuple containers are covered too.
        b = [a, a]
        c = (a, a)
        # Same layout as ``a`` except one branch holds a list of strings instead
        # of a tensor, which ``get_shape0`` is expected to reject.
        d = {
            'a': {
                'b': ["a", "b", "c", "d"]
            },
            'c': {
                'd': torch.randn(4)
            },
        }
        a = ttorch.as_tensor(a)
        assert get_shape0(a) == 4
        assert get_shape0(b) == 4
        assert get_shape0(c) == 4
        # Call without wrapping in ``assert``: a wrong *return value* would raise
        # AssertionError, which ``pytest.raises(Exception)`` would silently accept.
        with pytest.raises(Exception):
            get_shape0(d)

    def test_lists_to_dicts(self):
        """``lists_to_dicts`` transposes a list of mappings into a mapping of lists."""
        set_pkg_seed(12)
        # Empty input and non-mapping elements are rejected.
        with pytest.raises(ValueError):
            lists_to_dicts([])
        with pytest.raises(TypeError):
            lists_to_dicts([1])
        assert lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}]) == {1: [1, 2], 10: [3, 4]}
        # A list of namedtuples comes back as one namedtuple of the same class.
        T = namedtuple('T', ['location', 'race'])
        data = [T({'x': 1, 'y': 2}, 'zerg') for _ in range(3)]
        output = lists_to_dicts(data)
        assert isinstance(output, T) and output.__class__ == T
        assert len(output.location) == 3
        # ``recursive=True`` also transposes the nested dict under 'obs'.
        data = [{'value': torch.randn(1), 'obs': {'scalar': torch.randn(4)}} for _ in range(3)]
        output = lists_to_dicts(data, recursive=True)
        assert isinstance(output, dict)
        assert len(output['value']) == 3
        assert len(output['obs']['scalar']) == 3

    def test_dicts_to_lists(self):
        """``dicts_to_lists`` turns a mapping of lists into a list of mappings."""
        assert dicts_to_lists({1: [1, 2], 10: [3, 4]}) == [{1: 1, 10: 3}, {1: 2, 10: 4}]

    def test_squeeze(self):
        """``squeeze`` unwraps single-element containers and leaves others intact."""
        assert squeeze((4, )) == 4
        assert squeeze({'a': 4}) == 4
        # A multi-element list is returned as a tuple.
        assert squeeze([1, 3]) == (1, 3)
        data = np.random.randn(3)
        output = squeeze(data)
        assert (output == data).all()

    def test_default_get(self):
        """``default_get`` falls back to a default value/factory validated by ``judge_fn``."""
        assert default_get({}, 'a', default_value=1, judge_fn=lambda x: x < 2) == 1
        assert default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 2) == 1
        # A default that fails ``judge_fn`` raises AssertionError.
        with pytest.raises(AssertionError):
            default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 0)
        # A present key takes precedence over the default.
        assert default_get({'val': 1}, 'val', default_value=2) == 1

    def test_override(self):
        """``@override(base)`` accepts genuine overrides and raises NameError otherwise."""

        class foo(object):

            def fun(self):
                raise NotImplementedError

        class foo1(foo):

            @override(foo)
            def fun(self):
                return "a"

        # ``func`` is not defined on ``foo``; the decorator must reject it.
        # Without ``@override`` nothing in this class body could raise NameError.
        with pytest.raises(NameError):

            class foo2(foo):

                @override(foo)
                def func(self):
                    pass

        with pytest.raises(NotImplementedError):
            foo().fun()
        assert foo1().fun() == "a"

    def test_error_wrapper(self):
        """``error_wrapper`` passes results through and substitutes a default on error."""

        def good_ret(a, b=1):
            return a + b

        wrap_good_ret = error_wrapper(good_ret, 0)
        assert good_ret(1) == wrap_good_ret(1)

        def bad_ret(a, b=0):
            # Dividing by the default ``b=0`` always raises ZeroDivisionError.
            return a / b

        wrap_bad_ret = error_wrapper(bad_ret, 0)
        assert wrap_bad_ret(1) == 0
        # NOTE(review): this wrapper is only constructed, never called, so the
        # custom-message path of ``error_wrapper`` is not exercised by this test.
        wrap_bad_ret_with_customized_log = error_wrapper(bad_ret, 0, 'customized_information')

    def test_list_split(self):
        """``list_split`` cuts a list into ``step``-sized chunks plus a leftover tail."""
        data = list(range(10))
        output, residual = list_split(data, step=4)
        assert len(output) == 2
        assert output[1] == [4, 5, 6, 7]
        assert residual == [8, 9]
        # When the length divides evenly there is no residual.
        output, residual = list_split(data, step=5)
        assert len(output) == 2
        assert output[1] == [5, 6, 7, 8, 9]
        assert residual is None
class TestLimitedSpaceContainer:
    """Exercise the acquire/release bookkeeping of ``LimitedSpaceContainer``."""

    def test_container(self):
        """Walk a (0, 5) container through acquire, exhaust, grow, drain and shrink."""
        box = LimitedSpaceContainer(0, 5)

        # A single slot can be taken from an empty container.
        assert box.acquire_space()
        assert box.cur == 1

        # The remaining four slots are all claimed by one call.
        assert box.get_residual_space() == 4
        assert box.cur == box.max_val == 5

        # Once full, acquisition is refused.
        assert not box.acquire_space()

        # Growing the capacity makes exactly one more slot available.
        box.increase_space()
        assert box.acquire_space()

        # Give back all six slots; ``cur`` drops by one each time.
        for expected in range(5, -1, -1):
            box.release_space()
            assert box.cur == expected

        box.decrease_space()
        assert box.max_val == 5
| class TestDict: | |
def test_deep_merge_dicts(self):
    """``deep_merge_dicts`` overlays ``dict2`` onto ``dict1`` recursively."""
    dict1 = {
        'a': 3,
        'b': {
            'c': 3,
            'd': {
                'e': 6,
                'f': 5,
            }
        }
    }
    dict2 = {
        'b': {
            'c': 5,
            'd': 6,
            'g': 4,
        }
    }
    new_dict = deep_merge_dicts(dict1, dict2)
    # Keys present only in dict1 survive.
    assert new_dict['a'] == 3
    assert isinstance(new_dict['b'], dict)
    # Values from dict2 win over dict1's values.
    assert new_dict['b']['c'] == 5
    # Expected: dict2's scalar 'd' replaces dict1's nested dict under 'd'
    # (the same inputs give this result for ``deep_update`` in test_deep_update).
    assert new_dict['b']['d'] == 6
    # Keys new in dict2 are added.
    assert new_dict['b']['g'] == 4
def test_deep_update(self):
    """``deep_update`` guards against unknown keys unless whitelisted and can replace typed sub-dicts."""
    dict1 = {
        'a': 3,
        'b': {
            'c': 3,
            'd': {
                'e': 6,
                'f': 5,
            },
            'z': 4,
        }
    }
    dict2 = {
        'b': {
            'c': 5,
            'd': 6,
            'g': 4,
        }
    }
    # 'g' is not a key of dict1['b'] and new keys are disallowed, so this must fail.
    with pytest.raises(RuntimeError):
        deep_update(dict1, dict2, new_keys_allowed=False)
    # Whitelisting 'b' permits new sub-keys under it.
    # NOTE(review): if the failed call above already updated dict1 in place,
    # dict1 may be partially modified here; the assertions below hold either way.
    new2 = deep_update(dict1, dict2, new_keys_allowed=False, whitelist=['b'])
    assert new2['a'] == 3
    assert new2['b']['c'] == 5
    assert new2['b']['d'] == 6
    assert new2['b']['g'] == 4
    assert new2['b']['z'] == 4
    dict1 = {
        'a': 3,
        'b': {
            'type': 'old',
            'z': 4,
        }
    }
    dict2 = {
        'b': {
            'type': 'new',
            'c': 5,
        }
    }
    # 'b' changes its 'type' and is listed in override_all_if_type_changes,
    # so the whole sub-dict is replaced: 'z' disappears, 'c' comes from dict2.
    new3 = deep_update(dict1, dict2, new_keys_allowed=True, whitelist=[], override_all_if_type_changes=['b'])
    assert new3['a'] == 3
    assert new3['b']['type'] == 'new'
    assert new3['b']['c'] == 5
    assert 'z' not in new3['b']
def test_flatten_dict(self):
    """``flatten_dict`` joins nested keys with '/' into a single-level mapping."""
    # Named ``nested`` rather than ``dict`` to avoid shadowing the builtin.
    nested = {
        'a': 3,
        'b': {
            'c': 3,
            'd': {
                'e': 6,
                'f': 5,
            },
            'z': 4,
        }
    }
    flat = flatten_dict(nested)
    assert flat['a'] == 3
    assert flat['b/c'] == 3
    assert flat['b/d/e'] == 6
    assert flat['b/d/f'] == 5
    assert flat['b/z'] == 4
def test_one_time_warning(self):
    """Smoke test: emitting a one-time warning must not raise."""
    message = 'test_one_time_warning'
    one_time_warning(message)
def test_running_mean_std(self):
    """``RunningMeanStd`` accumulates mean/std across batches and resets cleanly."""
    running = RunningMeanStd()
    running.reset()
    running.update(np.arange(1, 10))
    assert running.mean == pytest.approx(5, abs=1e-4)
    assert running.std == pytest.approx(2.582030, abs=1e-6)
    # A second batch is folded in; samples 1..9 plus 2..10 have mean 99 / 18 = 5.5.
    running.update(np.arange(2, 11))
    assert running.mean == pytest.approx(5.5, abs=1e-4)
    assert running.std == pytest.approx(2.629981, abs=1e-6)
    # After reset, the same first batch must reproduce the first statistics.
    running.reset()
    running.update(np.arange(1, 10))
    assert running.mean == pytest.approx(5, abs=1e-4)
    assert running.std == pytest.approx(2.582030, abs=1e-6)
    new_shape = running.new_shape((2, 4), (3, ), (1, ))
    assert isinstance(new_shape, tuple) and len(new_shape) == 3
    # With a vector shape, mean/std are returned as torch tensors of that shape.
    running = RunningMeanStd(shape=(4, ))
    running.reset()
    running.update(np.random.random((10, 4)))
    assert isinstance(running.mean, torch.Tensor) and running.mean.shape == (4, )
    assert isinstance(running.std, torch.Tensor) and running.std.shape == (4, )
| def test_split_data_generator(self): | |
| def get_data(): | |
| return { | |
| 'obs': torch.randn(5), | |
| 'action': torch.randint(0, 10, size=(1, )), | |
| 'prev_state': [None, None], | |
| 'info': { | |
| 'other_obs': torch.randn(5) | |
| }, | |
| } | |
| data = [get_data() for _ in range(4)] | |
| data = lists_to_dicts(data) | |
| data['obs'] = torch.stack(data['obs']) | |
| data['action'] = torch.stack(data['action']) | |
| data['info'] = {'other_obs': torch.stack([t['other_obs'] for t in data['info']])} | |
| assert len(data['obs']) == 4 | |
| data['NoneKey'] = None | |
| generator = split_data_generator(data, 3) | |
| generator_result = list(generator) | |
| assert len(generator_result) == 2 | |
| assert generator_result[0]['NoneKey'] is None | |
| assert len(generator_result[0]['obs']) == 3 | |
| assert generator_result[0]['info']['other_obs'].shape == (3, 5) | |
| assert generator_result[1]['NoneKey'] is None | |
| assert len(generator_result[1]['obs']) == 3 | |
| assert generator_result[1]['info']['other_obs'].shape == (3, 5) | |
| generator = split_data_generator(data, 3, shuffle=False) | |