Spaces:
Running
Running
| import pytest | |
| import torch | |
| from easydict import EasyDict | |
| import os | |
| from ding.utils.data import offline_data_save_type, create_dataset, NaiveRLDataset, D4RLDataset, HDF5Dataset | |
# Config for the 'naive' data type: policy.collect points at ./expert.pkl.
cfg1 = {
    'policy': {
        'collect': {
            'data_type': 'naive',
            'data_path': './expert.pkl',
        },
    },
}
# Config for the 'hdf5' data type: sets env.norm_obs.use_norm and
# env.norm_obs.offline_stats.use_offline_stats, and points at ./expert_demos.hdf5.
cfg2 = {
    'env': {
        'norm_obs': {
            'use_norm': True,
            'offline_stats': {'use_offline_stats': True},
        },
    },
    'policy': {
        'collect': {
            'data_type': 'hdf5',
            'data_path': './expert_demos.hdf5',
        },
    },
}
# Config for the 'd4rl' data type; it is not part of ``cfgs`` below.
cfg3 = {
    'env': {'env_id': 'hopper-expert-v0'},
    'policy': {'collect': {'data_type': 'd4rl'}},
}
cfgs = [cfg1, cfg2]  # cfg3
unittest_args = ['naive', 'hdf5']
# fake transition & data
def _make_transition():
    """Build one fake transition dict with freshly allocated tensors."""
    return {
        'obs': torch.zeros((3, 1)),
        'next_obs': torch.zeros((3, 1)),
        'action': torch.zeros((1, 1)),
        'reward': torch.tensor((1, )),
        'done': False,
        'collect_iter': 0,
    }


# Kept as a module-level name for anything that refers to a single sample.
transition = _make_transition()
# Each sample is its own dict with its own tensors, so an in-place change to
# one element cannot leak into the other 31.
fake_data = [_make_transition() for _ in range(32)]
expert_data_path = './expert.pkl'
# Without parametrization pytest would look for a fixture named ``data_type``
# and error out; ``unittest_args`` supplies the values to run with.
@pytest.mark.parametrize('data_type', unittest_args)
@pytest.mark.unittest
def test_offline_data_save_type(data_type):
    """Save ``fake_data`` via ``offline_data_save_type`` for the given ``data_type``.

    The call receives ``expert_data_path`` ('./expert.pkl') as its target.
    NOTE(review): the later dataset tests presumably read the files written
    here, so this test likely has to run first; confirm.
    """
    offline_data_save_type(exp_data=fake_data, expert_data_path=expert_data_path, data_type=data_type)
# Without parametrization pytest would look for a fixture named ``cfg`` and
# error out; ``cfgs`` lists the configs to run with.
@pytest.mark.parametrize('cfg', cfgs)
@pytest.mark.unittest
def test_dataset(cfg):
    """Check that ``create_dataset`` accepts each config in ``cfgs`` without raising."""
    cfg = EasyDict(cfg)
    create_dataset(cfg)
# ``cfg1`` is the config whose data_type is 'naive'; without parametrization
# pytest would look for a fixture named ``cfg`` and error out.
@pytest.mark.parametrize('cfg', [cfg1])
@pytest.mark.unittest
def test_NaiveRLDataset(cfg):
    """Construct ``NaiveRLDataset`` from a config and from a path, then probe it.

    The dataset built from ``expert_data_path`` must support ``len()`` and
    return a non-None first item.
    """
    cfg = EasyDict(cfg)
    # Construction from a config object only has to succeed.
    NaiveRLDataset(cfg)
    dataset = NaiveRLDataset(expert_data_path)
    assert isinstance(len(dataset), int)
    assert dataset[0] is not None
# NOTE: the D4RL dataset test is disabled (cfg3 is likewise left out of ``cfgs``).
# @pytest.mark.parametrize('cfg', [cfg3])
# @pytest.mark.unittest
# def test_D4RLDataset(cfg):
#     cfg = EasyDict(cfg)
#     dataset = D4RLDataset(cfg)
# ``cfg2`` is the config whose data_type is 'hdf5'; without parametrization
# pytest would look for a fixture named ``cfg`` and error out.
@pytest.mark.parametrize('cfg', [cfg2])
@pytest.mark.unittest
def test_HDF5Dataset(cfg):
    """Build an ``HDF5Dataset`` from ``cfg`` and check its statistics and access.

    Checks that ``mean`` and ``std`` are populated, that the first component
    of the per-column mean of the stored 'obs' data is 0, and that the dataset
    supports ``len()`` and indexing.
    NOTE(review): presumably relies on './expert_demos.hdf5' having been
    written by the save test; confirm.
    """
    cfg = EasyDict(cfg)
    dataset = HDF5Dataset(cfg)
    assert dataset.mean is not None and dataset.std[0] is not None
    assert dataset._data['obs'].mean(0)[0] == 0
    assert isinstance(len(dataset), int)
    assert dataset[0] is not None
# Registered as an autouse fixture so pytest actually invokes it; as a plain
# function it would never run and the generated files would be left behind.
@pytest.fixture(scope="session", autouse=True)
def cleanup(request):
    """Register a finalizer that deletes the data files this module works with.

    Args:
        request: pytest's ``request`` fixture, used to attach the finalizer.
    """

    def remove_test_dir():
        # Remove each artifact only if it exists; a missing file is not an error.
        for path in ('./expert.pkl', './expert_demos.hdf5'):
            if os.path.exists(path):
                os.remove(path)

    request.addfinalizer(remove_test_dir)