Spaces:
Running
Running
| import pytest | |
| from collections import namedtuple | |
| import random | |
| import numpy as np | |
| import torch | |
| from ding.utils.data import timestep_collate, default_collate, default_decollate, diff_shape_collate | |
# Sizes shared by the timestep-collate fixtures and assertions below:
# B samples are collated into one batch, and every per-key list holds T timesteps.
B = 4
T = 3
class TestTimestepCollate:
    """Tests for ``timestep_collate`` applied to ``B`` samples of ``T`` timesteps each."""

    # Keys carried by every sample; the collated batch must expose exactly these.
    _EXPECTED_KEYS = {'obs', 'reward', 'done', 'prev_state', 'action'}

    def get_data(self):
        """Build one sample: a dict mapping each key to a list with one entry per timestep.

        ``prev_state`` holds a pair of 3-element tensors per timestep.
        """
        return {
            'obs': [torch.randn(4) for _ in range(T)],
            'reward': [torch.FloatTensor([0]) for _ in range(T)],
            'done': [False for _ in range(T)],
            'prev_state': [(torch.randn(3), torch.randn(3)) for _ in range(T)],
            'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
        }

    def get_multi_shape_state_data(self):
        """Build one sample whose per-timestep ``prev_state`` is a list of three tuples.

        The tuples differ in both length and tensor sizes, unlike the uniform pair
        produced by ``get_data``.
        """
        return {
            'obs': [torch.randn(4) for _ in range(T)],
            'reward': [torch.FloatTensor([0]) for _ in range(T)],
            'done': [False for _ in range(T)],
            'prev_state': [
                [(torch.randn(3), torch.randn(5)), (torch.randn(4), ), (torch.randn(5), torch.randn(6))]
                for _ in range(T)
            ],
            'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
        }

    def _check_common_fields(self, batch):
        """Assert the parts of a collated batch that are identical for both sample layouts."""
        assert isinstance(batch, dict)
        assert set(batch.keys()) == self._EXPECTED_KEYS
        # Tensor fields are expected time-major: (T, B, ...).
        assert batch['obs'].shape == (T, B, 4)
        assert batch['reward'].shape == (T, B)
        assert batch['done'].shape == (T, B)
        assert batch['done'].dtype == torch.bool
        # prev_state and action are expected to remain Python lists indexed by timestep.
        assert isinstance(batch['prev_state'], list)
        assert isinstance(batch['action'], list)
        assert len(batch['action']) == T
        assert batch['action'][0][0].shape == (B, 3)
        assert batch['action'][0][1].shape == (B, 5)

    def test(self):
        """Collate ``B`` samples for each ``prev_state`` layout and check the batch structure."""
        batch = timestep_collate([self.get_data() for _ in range(B)])
        self._check_common_fields(batch)
        assert len(batch['prev_state']) == T
        assert len(batch['prev_state'][0]) == B

        # hidden_state might contain multi prev_states with different shapes
        batch = timestep_collate([self.get_multi_shape_state_data() for _ in range(B)])
        self._check_common_fields(batch)
        assert len(batch['prev_state']) == T
        assert len(batch['prev_state'][0]) == B
        # Each per-sample entry must still hold its three differently shaped states.
        assert len(batch['prev_state'][0][0]) == 3
class TestDefaultCollate:
    """Tests for ``default_collate`` on numpy inputs and on plain Python values."""

    def test_numpy(self):
        """Check numpy arrays, floats taken from numpy, and a rejected string array."""
        arrays = [np.random.randn(4, 3).astype(np.float64) for _ in range(5)]
        stacked = default_collate(arrays)
        assert stacked.shape == (5, 4, 3)
        assert stacked.dtype == torch.float64

        scalars = [float(np.random.randn(1)[0]) for _ in range(6)]
        stacked = default_collate(scalars)
        assert stacked.shape == (6, )
        assert stacked.dtype == torch.float32

        # Arrays of strings are expected to be rejected.
        string_arrays = [np.array(['str']) for _ in range(3)]
        with pytest.raises(TypeError):
            default_collate(string_arrays)

    def test_basic(self):
        """Check floats, ints, strings, namedtuples, unsupported objects and dict samples."""
        floats = [random.random() for _ in range(3)]
        collated = default_collate(floats)
        assert collated.shape == (3, )
        assert collated.dtype == torch.float32

        ints = [random.randint(0, 10) for _ in range(3)]
        collated = default_collate(ints)
        assert collated.shape == (3, )
        assert collated.dtype == torch.int64

        # Strings are expected to come back as a sequence of the same strings.
        strings = ['str' for _ in range(4)]
        collated = default_collate(strings)
        assert len(collated) == 4
        assert all(item == 'str' for item in collated)

        # The namedtuple type must be preserved, with each field collated separately.
        Pair = namedtuple('T', ['x', 'y'])
        pairs = [Pair(1, 2) for _ in range(4)]
        collated = default_collate(pairs)
        assert isinstance(collated, Pair)
        assert collated.x.shape == (4, )
        assert collated.x.eq(1).sum() == 4
        assert collated.y.shape == (4, )
        assert collated.y.eq(2).sum() == 4

        # Plain objects are expected to be rejected.
        opaque = [object() for _ in range(4)]
        with pytest.raises(TypeError):
            default_collate(opaque)

        samples = [{'collate_ignore_data': random.random()} for _ in range(4)]
        collated = default_collate(samples)
        assert isinstance(collated, dict)
        assert len(collated['collate_ignore_data']) == 4
class TestDefaultDecollate:
    """Tests for ``default_decollate`` on tensor, list-of-tensor and dict batches."""

    def test(self):
        """Check the rejected input, then the per-sample shapes for each batch layout."""
        # A list of plain objects is not a supported batch and must raise.
        with pytest.raises(TypeError):
            default_decollate([object() for _ in range(4)])

        # Tensor batch: one item per entry of the leading dimension.
        samples = default_decollate(torch.randn(4, 3, 5))
        assert len(samples) == 4
        assert all(s.shape == (3, 5) for s in samples)

        # List of tensors sharing a leading dimension of 8: each sample holds one slice of each.
        samples = default_decollate([torch.randn(8, 2, 4), torch.randn(8, 5)])
        assert len(samples) == 8
        assert all(s[0].shape == (2, 4) and s[1].shape == (5, ) for s in samples)

        # Dict batch mixing tensors with a per-sample list of state tuples.
        batch = {
            'logit': torch.randn(4, 13),
            'action': torch.randint(0, 13, size=(4, )),
            'prev_state': [(torch.zeros(3, 1, 12), torch.zeros(3, 1, 12)) for _ in range(4)],
        }
        samples = default_decollate(batch)
        assert isinstance(samples, list)
        assert len(samples) == 4
        assert all(s['logit'].shape == (13, ) for s in samples)
        # The 1-D action batch is expected to yield shape (1, ) per sample, not a 0-dim tensor.
        assert all(s['action'].shape == (1, ) for s in samples)
        assert all(len(s['prev_state']) == 2 and s['prev_state'][0].shape == (3, 1, 12) for s in samples)
class TestDiffShapeCollate:
    """Tests for ``diff_shape_collate`` on samples whose fields may differ in shape or be ``None``."""

    def test(self):
        """Check the rejected input, then tensor/array samples and scalar samples."""
        # Plain objects are expected to be rejected.
        opaque = [object() for _ in range(4)]
        with pytest.raises(TypeError):
            diff_shape_collate(opaque)

        # item1 differs in length and item2 is None in one sample; item3/item4 match.
        first = {
            'item1': torch.randn(4),
            'item2': None,
            'item3': torch.randn(3),
            'item4': np.random.randn(5, 6)
        }
        second = {
            'item1': torch.randn(5),
            'item2': torch.randn(6),
            'item3': torch.randn(3),
            'item4': np.random.randn(5, 6)
        }
        collated = diff_shape_collate([first, second])
        # Mismatched fields are expected to stay as per-sample lists.
        assert isinstance(collated['item1'], list)
        assert len(collated['item1']) == 2
        assert isinstance(collated['item2'], list)
        assert len(collated['item2']) == 2
        assert collated['item2'][0] is None
        # Matching fields are expected to gain a leading dimension of 2.
        assert collated['item3'].shape == (2, 3)
        assert collated['item4'].shape == (2, 5, 6)

        # Scalar samples: item1 is None in the second sample only.
        first = {'item1': 1, 'item2': 3, 'item3': 2.0}
        second = {'item1': None, 'item2': 4, 'item3': 2.0}
        collated = diff_shape_collate([first, second])
        assert isinstance(collated['item1'], list)
        assert len(collated['item1']) == 2
        assert collated['item1'][1] is None
        assert collated['item2'].shape == (2, )
        assert collated['item2'].dtype == torch.int64
        assert collated['item3'].shape == (2, )
        assert collated['item3'].dtype == torch.float32