Spaces:
Sleeping
Sleeping
| import pytest | |
| import numpy as np | |
| import random | |
| import torch | |
| from ding.data.level_replay.level_sampler import LevelSampler | |
def test_level_sampler():
    """Smoke-test ``LevelSampler`` with the ``min_margin`` strategy.

    Builds a synthetic rollout batch for ``collector_env_num`` environments,
    feeds it to ``LevelSampler.update_with_rollouts`` and checks that
    ``sample()`` returns an ``int`` that is one of the training seeds.
    """
    # Local RNGs keep the synthetic inputs reproducible without mutating the
    # global ``random`` / ``torch`` state that other tests may depend on.
    rng = random.Random(0)
    gen = torch.Generator().manual_seed(0)

    num_seeds = 500
    obs_shape = [3, 64, 64]
    action_shape = 15
    collector_env_num = 16
    level_replay_dict = dict(
        strategy='min_margin',
        score_transform='rank',
        temperature=0.1,
    )
    # NOTE(review): logits have N columns while action_shape is 15 — confirm
    # whether this mismatch is intentional.
    N = 10
    collector_sample_length = 160
    # The rollout is split evenly across the collector environments.
    steps_per_env = collector_sample_length // collector_env_num

    train_seeds = list(range(num_seeds))
    level_sampler = LevelSampler(train_seeds, obs_shape, action_shape, collector_env_num, level_replay_dict)

    value = torch.randn(collector_sample_length, generator=gen)
    reward = torch.randn(collector_sample_length, generator=gen)
    adv = torch.randn(collector_sample_length, generator=gen)
    # NOTE(review): done flags are usually 0/1; randn values are almost never
    # exactly zero — confirm this is the intended input for the sampler.
    done = torch.randn(collector_sample_length, generator=gen)
    logit = torch.randn(collector_sample_length, N, generator=gen)

    # One seed per environment. ``randint`` is inclusive on both ends, so the
    # upper bound is ``num_seeds - 1`` to stay inside ``train_seeds``.
    seeds = [rng.randint(0, num_seeds - 1) for _ in range(collector_env_num)]
    # Each environment's seed is repeated for its ``steps_per_env`` contiguous
    # steps (env-major layout), as float32 like the original torch.Tensor(list).
    all_seeds = torch.tensor(seeds, dtype=torch.float32).repeat_interleave(steps_per_env)

    train_data = {'value': value, 'reward': reward, 'adv': adv, 'done': done, 'logit': logit, 'seed': all_seeds}
    level_sampler.update_with_rollouts(train_data, collector_env_num)

    sample_seed = level_sampler.sample()
    assert isinstance(sample_seed, int)
    assert sample_seed in train_seeds