# NOTE: removed non-code page residue ("Spaces: / Running / Running") left over from web extraction.
| import os | |
| import copy | |
| from typing import Union, Any, Optional, List | |
| import numpy as np | |
| import math | |
| import hickle | |
| from easydict import EasyDict | |
| from ding.worker.replay_buffer import IBuffer | |
| from ding.utils import LockContext, LockContextType, BUFFER_REGISTRY, build_logger | |
| from .utils import UsedDataRemover, PeriodicThruputMonitor | |
class NaiveReplayBuffer(IBuffer):
    r"""
    Overview:
        Naive replay buffer, can store and sample data.
        A naive implementation of replay buffer with no priority or any other advanced features.
        This buffer refers to multi-thread/multi-process and guarantees thread-safety, which means that methods like
        ``sample``, ``push``, ``clear`` are all mutual to each other.
    Interface:
        start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
    Property:
        replay_buffer_size, push_count
    """

    config = dict(
        type='naive',
        replay_buffer_size=10000,
        deepcopy=False,
        # default `False` for serial pipeline
        enable_track_used_data=False,
        periodic_thruput_seconds=60,
    )

    def __init__(
            self,
            cfg: 'EasyDict',  # noqa
            tb_logger: Optional['SummaryWriter'] = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'buffer',
    ) -> None:
        """
        Overview:
            Initialize the buffer
        Arguments:
            - cfg (:obj:`dict`): Config dict.
            - tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually get this argument in serial mode.
            - exp_name (:obj:`Optional[str]`): Name of this experiment.
            - instance_name (:obj:`Optional[str]`): Name of this instance.
        """
        self._exp_name = exp_name
        self._instance_name = instance_name
        self._cfg = cfg
        self._replay_buffer_size = self._cfg.replay_buffer_size
        self._deepcopy = self._cfg.deepcopy
        # ``_data`` is a circular queue to store data (full data or meta data).
        self._data = [None for _ in range(self._replay_buffer_size)]
        # Current valid data count, indicating how many elements in ``self._data`` are valid.
        self._valid_count = 0
        # How many pieces of data have been pushed into this buffer, should be no less than ``_valid_count``.
        self._push_count = 0
        # Point to the tail position where next data can be inserted, i.e. latest inserted data's next position.
        self._tail = 0
        # Lock to guarantee thread safety.
        self._lock = LockContext(type_=LockContextType.THREAD_LOCK)
        # Set by ``close`` so that repeated close calls (e.g. explicit close followed by ``__del__``) are no-ops.
        self._end_flag = False
        self._enable_track_used_data = self._cfg.enable_track_used_data
        if self._enable_track_used_data:
            self._used_data_remover = UsedDataRemover()
        if tb_logger is not None:
            self._logger, _ = build_logger(
                './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
            )
            self._tb_logger = tb_logger
        else:
            self._logger, self._tb_logger = build_logger(
                './{}/log/{}'.format(self._exp_name, self._instance_name),
                self._instance_name,
            )
        # Periodic thruput. Here by default, monitor range is 60 seconds. You can modify it for free.
        self._periodic_thruput_monitor = PeriodicThruputMonitor(
            self._instance_name, EasyDict(seconds=self._cfg.periodic_thruput_seconds), self._logger, self._tb_logger
        )

    def start(self) -> None:
        """
        Overview:
            Start the buffer's used_data_remover thread if track_used_data is enabled.
        """
        if self._enable_track_used_data:
            self._used_data_remover.start()

    def close(self) -> None:
        """
        Overview:
            Clear the buffer; Join the buffer's used_data_remover thread if track_used_data is enabled.
            Idempotent: a second call returns immediately (``__del__`` also calls this method).
        """
        if self._end_flag:
            return
        self._end_flag = True
        self.clear()
        if self._enable_track_used_data:
            self._used_data_remover.close()
        self._tb_logger.flush()
        self._tb_logger.close()

    def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
        r"""
        Overview:
            Push a data into buffer.
        Arguments:
            - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \
                (in `Any` type), or many (in `List[Any]` type).
            - cur_collector_envstep (:obj:`int`): Collector's current env step. \
                Not used in naive buffer, but preserved for compatibility.
        """
        if isinstance(data, list):
            self._extend(data, cur_collector_envstep)
            self._periodic_thruput_monitor.push_data_count += len(data)
        else:
            self._append(data, cur_collector_envstep)
            self._periodic_thruput_monitor.push_data_count += 1

    def sample(self,
               size: int,
               cur_learner_iter: int,
               sample_range: slice = None,
               replace: bool = False) -> Optional[list]:
        """
        Overview:
            Sample data with length ``size``.
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled.
            - cur_learner_iter (:obj:`int`): Learner's current iteration. \
                Not used in naive buffer, but preserved for compatibility.
            - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
                means only sample among the last 10 data
            - replace (:obj:`bool`): Whether sample with replacement
        Returns:
            - sample_data (:obj:`list`): A list of data with length ``size``, or ``None`` if the \
                buffer cannot provide enough data.
        """
        if size == 0:
            return []
        # NOTE(review): the availability check runs outside ``self._lock``, so a concurrent ``clear``
        # between check and sampling is theoretically possible — confirm whether callers rely on this.
        can_sample = self._sample_check(size, replace)
        if not can_sample:
            return None
        with self._lock:
            indices = self._get_indices(size, sample_range, replace)
            sample_data = self._sample_with_indices(indices, cur_learner_iter)
        self._periodic_thruput_monitor.sample_data_count += len(sample_data)
        return sample_data

    def save_data(self, file_name: str):
        """
        Overview:
            Dump the whole internal circular queue (including ``None`` placeholders) to ``file_name``
            with hickle, creating the parent directory if needed.
        Arguments:
            - file_name (:obj:`str`): Target file path.
        """
        dir_name = os.path.dirname(file_name)
        if dir_name:
            # ``exist_ok`` avoids the check-then-create race of the previous exists()+makedirs pair.
            os.makedirs(dir_name, exist_ok=True)
        hickle.dump(py_obj=self._data, file_obj=file_name)

    def load_data(self, file_name: str):
        """
        Overview:
            Load data dumped by ``save_data`` and push it into the buffer.
        Arguments:
            - file_name (:obj:`str`): Source file path.
        """
        # NOTE(review): ``save_data`` stores the raw queue which may contain ``None`` placeholders;
        # verify dumps are taken from a full buffer before loading.
        self.push(hickle.load(file_name), 0)

    def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None:
        r"""
        Overview:
            Append a data item into ``self._data``.
        Arguments:
            - ori_data (:obj:`Any`): The data which will be inserted.
            - cur_collector_envstep (:obj:`int`): Not used in this method, but preserved for compatibility.
        """
        with self._lock:
            if self._deepcopy:
                data = copy.deepcopy(ori_data)
            else:
                data = ori_data
            self._push_count += 1
            if self._data[self._tail] is None:
                self._valid_count += 1
                self._periodic_thruput_monitor.valid_count = self._valid_count
            elif self._enable_track_used_data:
                # The slot is occupied: the overwritten item is "used" and handed to the remover.
                self._used_data_remover.add_used_data(self._data[self._tail])
            self._data[self._tail] = data
            self._tail = (self._tail + 1) % self._replay_buffer_size

    def _extend(self, ori_data: List[Any], cur_collector_envstep: int = -1) -> None:
        r"""
        Overview:
            Extend a data list into the circular queue. Bookkeeping matches ``_append`` on a
            per-slot basis: each ``None`` slot filled increases ``_valid_count``, each occupied
            slot overwritten is (optionally) handed to the used-data remover. This per-slot
            accounting cannot overshoot ``_replay_buffer_size`` (the previous per-chunk
            accounting could, which would later break the assert in ``_get_indices``).
        Arguments:
            - ori_data (:obj:`List[Any]`): The data list.
            - cur_collector_envstep (:obj:`int`): Not used in this method, but preserved for compatibility.
        """
        with self._lock:
            if self._deepcopy:
                data = copy.deepcopy(ori_data)
            else:
                data = ori_data
            length = len(data)
            self._push_count += length
            if self._tail + length <= self._replay_buffer_size:
                # Fast path: the new data fits before the physical end of the queue.
                for i in range(length):
                    if self._data[self._tail + i] is None:
                        self._valid_count += 1
                    elif self._enable_track_used_data:
                        self._used_data_remover.add_used_data(self._data[self._tail + i])
                self._data[self._tail:self._tail + length] = data
            else:
                # Wrap-around path: copy in chunks, restarting from index 0 when the end is reached.
                # If ``length`` exceeds the buffer size, later chunks overwrite earlier ones and only
                # the last ``_replay_buffer_size`` items survive, matching single-item push semantics.
                new_tail = self._tail
                data_start = 0
                residual_num = length
                while residual_num > 0:
                    space = self._replay_buffer_size - new_tail
                    L = min(space, residual_num)
                    for i in range(L):
                        if self._data[new_tail + i] is None:
                            self._valid_count += 1
                        elif self._enable_track_used_data:
                            self._used_data_remover.add_used_data(self._data[new_tail + i])
                    self._data[new_tail:new_tail + L] = data[data_start:data_start + L]
                    residual_num -= L
                    new_tail = 0
                    data_start += L
            self._periodic_thruput_monitor.valid_count = self._valid_count
            # Update ``tail`` after the whole list is pushed into the buffer.
            self._tail = (self._tail + length) % self._replay_buffer_size

    def _sample_check(self, size: int, replace: bool = False) -> bool:
        r"""
        Overview:
            Check whether this buffer has more than ``size`` datas to sample.
        Arguments:
            - size (:obj:`int`): Number of data that will be sampled.
            - replace (:obj:`bool`): Whether sample with replacement.
        Returns:
            - can_sample (:obj:`bool`): Whether this buffer can sample enough data.
        """
        if self._valid_count == 0:
            print("The buffer is empty")
            return False
        if self._valid_count < size and not replace:
            print(
                "No enough elements for sampling without replacement (expect: {} / current: {})".format(
                    size, self._valid_count
                )
            )
            return False
        else:
            return True

    def update(self, info: dict) -> None:
        r"""
        Overview:
            Naive Buffer does not need to update any info, but this method is preserved for compatibility.
        Arguments:
            - info (:obj:`dict`): Ignored.
        """
        # Implicit string concatenation instead of a backslash continuation, which leaked the
        # block's leading whitespace into the printed message.
        print(
            '[BUFFER WARNING] Naive Buffer does not need to update any info, '
            'but `update` method is preserved for compatibility.'
        )

    def clear(self) -> None:
        """
        Overview:
            Clear all the data and reset the related variables.
        """
        with self._lock:
            for i in range(len(self._data)):
                if self._data[i] is not None:
                    if self._enable_track_used_data:
                        self._used_data_remover.add_used_data(self._data[i])
                    self._data[i] = None
            self._valid_count = 0
            self._periodic_thruput_monitor.valid_count = self._valid_count
            self._push_count = 0
            self._tail = 0

    def __del__(self) -> None:
        """
        Overview:
            Release resources when the object is garbage-collected. ``close`` is idempotent,
            so an earlier explicit ``close`` makes this a no-op.
        """
        self.close()

    def _get_indices(self, size: int, sample_range: slice = None, replace: bool = False) -> list:
        r"""
        Overview:
            Get the sample index list.
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled
            - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
                means only sample among the last 10 data
            - replace (:obj:`bool`): Whether sample with replacement.
        Returns:
            - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
        """
        assert self._valid_count <= self._replay_buffer_size
        # When full, every slot is valid; otherwise only slots before ``_tail`` hold data.
        if self._valid_count == self._replay_buffer_size:
            tail = self._replay_buffer_size
        else:
            tail = self._tail
        if sample_range is None:
            indices = list(np.random.choice(a=tail, size=size, replace=replace))
        else:
            indices = list(range(tail))[sample_range]
            indices = list(np.random.choice(indices, size=size, replace=replace))
        return indices

    def _sample_with_indices(self, indices: List[int], cur_learner_iter: int) -> list:
        r"""
        Overview:
            Sample data with ``indices``.
        Arguments:
            - indices (:obj:`List[int]`): A list including all the sample indices.
            - cur_learner_iter (:obj:`int`): Not used in this method, but preserved for compatibility.
        Returns:
            - data (:obj:`list`): Sampled data.
        """
        data = []
        for idx in indices:
            assert self._data[idx] is not None, idx
            if self._deepcopy:
                copy_data = copy.deepcopy(self._data[idx])
            else:
                copy_data = self._data[idx]
            data.append(copy_data)
        return data

    def count(self) -> int:
        """
        Overview:
            Count how many valid datas there are in the buffer.
        Returns:
            - count (:obj:`int`): Number of valid data.
        """
        return self._valid_count

    def state_dict(self) -> dict:
        """
        Overview:
            Provide a state dict to keep a record of current buffer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                With the dict, one can easily reproduce the buffer.
        """
        return {
            'data': self._data,
            'tail': self._tail,
            'valid_count': self._valid_count,
            'push_count': self._push_count,
        }

    def load_state_dict(self, _state_dict: dict) -> None:
        """
        Overview:
            Load state dict to reproduce the buffer.
        Arguments:
            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                If it only contains the ``data`` key, the data is re-pushed; otherwise each value is \
                restored verbatim onto the matching ``_``-prefixed attribute.
        """
        assert 'data' in _state_dict
        if set(_state_dict.keys()) == set(['data']):
            self._extend(_state_dict['data'])
        else:
            for k, v in _state_dict.items():
                setattr(self, '_{}'.format(k), v)

    def replay_buffer_size(self) -> int:
        """
        Overview:
            Return the maximum capacity of the buffer.
        """
        # NOTE(review): the class docstring lists this as a "Property" but no decorator is present
        # here; kept as a plain method to preserve the call-style interface — confirm against IBuffer.
        return self._replay_buffer_size

    def push_count(self) -> int:
        """
        Overview:
            Return how many items have been pushed into the buffer in total.
        """
        return self._push_count
class ElasticReplayBuffer(NaiveReplayBuffer):
    r"""
    Overview:
        Elastic replay buffer, it stores data and supports dynamically changing the buffer size.
        A naive implementation of replay buffer with no priority or any other advanced features.
        This buffer refers to multi-thread/multi-process and guarantees thread-safety, which means that methods like
        ``sample``, ``push``, ``clear`` are all mutual to each other.
    Interface:
        start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
    Property:
        replay_buffer_size, push_count
    """

    config = dict(
        type='elastic',
        replay_buffer_size=10000,
        deepcopy=False,
        # default `False` for serial pipeline
        enable_track_used_data=False,
        periodic_thruput_seconds=60,
    )

    def __init__(
            self,
            cfg: 'EasyDict',  # noqa
            tb_logger: Optional['SummaryWriter'] = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'buffer',
    ) -> None:
        """
        Overview:
            Initialize the buffer
        Arguments:
            - cfg (:obj:`dict`): Config dict. Must additionally provide ``set_buffer_size``, a \
                callable mapping an envstep to the elastic sampling window size.
            - tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually get this argument in serial mode.
            - exp_name (:obj:`Optional[str]`): Name of this experiment.
            - instance_name (:obj:`Optional[str]`): Name of this instance.
        """
        super().__init__(cfg, tb_logger, exp_name, instance_name)
        self._set_buffer_size = self._cfg.set_buffer_size
        # ``_current_buffer_size`` restricts how many of the newest samples may be used for sampling.
        self._current_buffer_size = self._set_buffer_size(0)  # Size at the 0-th envstep.

    def _sample_check(self, size: int, replace: bool = False) -> bool:
        r"""
        Overview:
            Check whether this buffer has more than ``size`` datas to sample, under the
            elastic limit ``_current_buffer_size``.
        Arguments:
            - size (:obj:`int`): Number of data that will be sampled.
            - replace (:obj:`bool`): Whether sample with replacement.
        Returns:
            - can_sample (:obj:`bool`): Whether this buffer can sample enough data.
        """
        valid_count = min(self._valid_count, self._current_buffer_size)
        if valid_count == 0:
            print("The buffer is empty")
            return False
        if valid_count < size and not replace:
            # Report the effective (elastic-limited) count — that is what constrains sampling,
            # not the raw ``_valid_count``.
            print(
                "No enough elements for sampling without replacement (expect: {} / current: {})".format(
                    size, valid_count
                )
            )
            return False
        return True

    def _get_indices(self, size: int, sample_range: slice = None, replace: bool = False) -> list:
        r"""
        Overview:
            Get the sample index list, restricted to the newest ``_current_buffer_size`` items.
        Arguments:
            - size (:obj:`int`): The number of the data that will be sampled.
            - sample_range (:obj:`slice`): Not supported by this buffer; must be ``None``.
            - replace (:obj:`bool`): Whether sample with replacement.
        Returns:
            - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
        """
        assert self._valid_count <= self._replay_buffer_size
        assert sample_range is None  # not support
        # Renamed from ``range`` to avoid shadowing the builtin.
        sample_scope = min(self._valid_count, self._current_buffer_size)
        # Draw offsets counted backwards from the newest item, then map them onto the circular queue.
        indices = list(
            (self._tail - 1 - np.random.choice(a=sample_scope, size=size, replace=replace)) % self._replay_buffer_size
        )
        return indices

    def update(self, envstep):
        """
        Overview:
            Refresh the elastic sampling window according to the current collector envstep.
            Note: this repurposes the base-class ``update(info)`` hook with a different argument.
        Arguments:
            - envstep (:obj:`int`): Collector's current env step.
        """
        self._current_buffer_size = self._set_buffer_size(envstep)
class SequenceReplayBuffer(NaiveReplayBuffer):
    r"""
    Overview:
        Replay buffer that samples fixed-length sub-sequences out of stored episode chunks.
        The sampling logic assumes ``self._data`` is laid out in contiguous chunks of
        ``_EPISODE_LEN`` items each.
    Interface:
        start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
    Property:
        replay_buffer_size, push_count
    """

    # Episode chunk length assumed by the sampling logic (was a repeated magic number 500).
    # NOTE(review): verify this matches the collector's episode length in actual usage.
    _EPISODE_LEN = 500

    def sample(
            self,
            batch: int,
            sequence: int,
            cur_learner_iter: int,
            sample_range: slice = None,
            replace: bool = False
    ) -> Optional[list]:
        """
        Overview:
            Sample ``batch`` sub-sequences, each of length ``sequence``.
        Arguments:
            - batch (:obj:`int`): The number of the sequences that will be sampled.
            - sequence (:obj:`int`): The length of the sequence of a data that will be sampled.
            - cur_learner_iter (:obj:`int`): Learner's current iteration. \
                Not used in naive buffer, but preserved for compatibility.
            - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
                means only sample among the last 10 data. Not implemented by this buffer.
            - replace (:obj:`bool`): Whether sample with replacement
        Returns:
            - sample_data (:obj:`list`): A list of ``batch`` sequences, or ``None`` if the \
                buffer cannot provide enough data.
        """
        if batch == 0:
            return []
        can_sample = self._sample_check(batch * sequence, replace)
        if not can_sample:
            return None
        with self._lock:
            indices = self._get_indices(batch, sequence, sample_range, replace)
            sample_data = self._sample_with_indices(indices, sequence, cur_learner_iter)
        self._periodic_thruput_monitor.sample_data_count += len(sample_data)
        return sample_data

    def _get_indices(self, size: int, sequence: int, sample_range: slice = None, replace: bool = False) -> list:
        r"""
        Overview:
            Get the list of sequence start indices: repeatedly pick a random episode chunk and a
            random start position that leaves room for a full ``sequence`` within the chunk.
        Arguments:
            - size (:obj:`int`): The number of the sequences that will be sampled.
            - sequence (:obj:`int`): The length of each sampled sequence.
            - sample_range (:obj:`slice`): Not implemented; must be ``None``.
            - replace (:obj:`bool`): Unused here; start positions are drawn independently.
        Returns:
            - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
        """
        assert self._valid_count <= self._replay_buffer_size
        if self._valid_count == self._replay_buffer_size:
            tail = self._replay_buffer_size
        else:
            tail = self._tail
        ep_len = self._EPISODE_LEN
        episodes = math.ceil(self._valid_count / ep_len)
        indices = []
        if sample_range is not None:
            raise NotImplementedError("sample_range is not implemented in this version")
        # NOTE(review): if every episode chunk is shorter than ``sequence`` this loop never
        # terminates — confirm callers guarantee at least one sufficiently long episode.
        while len(indices) < size:
            episode = np.random.choice(episodes)
            start = episode * ep_len
            # Length of this chunk; the last chunk may be only partially filled (up to ``tail``).
            length = min(tail - start, ep_len)
            available = length - sequence
            if available < 1:
                continue
            # Valid start positions are ``start .. start + available`` inclusive.
            # (A dead ``list(range(...))`` statement from the original was removed here.)
            indices.append(np.random.randint(start, start + available + 1))
        return indices

    def _sample_with_indices(self, indices: List[int], sequence: int, cur_learner_iter: int) -> list:
        r"""
        Overview:
            Sample sequences of length ``sequence`` starting at each index in ``indices``.
        Arguments:
            - indices (:obj:`List[int]`): A list including all the sequence start indices.
            - sequence (:obj:`int`): The length of each sampled sequence.
            - cur_learner_iter (:obj:`int`): Not used in this method, but preserved for compatibility.
        Returns:
            - data (:obj:`list`): Sampled data; each element is a slice of ``sequence`` items.
        """
        data = []
        for idx in indices:
            assert self._data[idx] is not None, idx
            if self._deepcopy:
                copy_data = copy.deepcopy(self._data[idx:idx + sequence])
            else:
                copy_data = self._data[idx:idx + sequence]
            data.append(copy_data)
        return data