Spaces:
Running
Running
| from typing import Union, Dict, Any, List | |
| from abc import ABC, abstractmethod | |
| import copy | |
| from easydict import EasyDict | |
| from ding.utils import import_module, BUFFER_REGISTRY | |
class IBuffer(ABC):
    """
    Overview:
        Buffer interface. Apart from ``default_config``, every method defined here only raises
        ``NotImplementedError``, so a concrete buffer has to override the ones it supports.
    Interfaces:
        default_config, push, update, sample, clear, count, save_data, load_data, state_dict, load_state_dict
    """

    # NOTE(review): ``abstractmethod`` is imported at the top of this file but is not applied to the
    # methods below, so an incomplete subclass fails when the missing method is called rather than
    # when it is instantiated. Confirm whether the methods are meant to be abstract before changing it.

    @classmethod
    def default_config(cls) -> EasyDict:
        """
        Overview:
            Default config of this buffer class. Must be called on a class that defines a ``config``
            attribute; ``IBuffer`` itself does not define one in this block.
        Returns:
            - default_config (:obj:`EasyDict`): A deep copy of ``cls.config`` with an extra ``cfg_type`` \
                key set to the class name followed by ``'Dict'``.
        """
        # Deep copy so that edits to the returned config never leak back into the class attribute.
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
        """
        Overview:
            Push a data into buffer.
        Arguments:
            - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \
                (in `Any` type), or many (in `List[Any]` type).
            - cur_collector_envstep (:obj:`int`): Collector's current env step.
        """
        raise NotImplementedError

    def update(self, info: Dict[str, list]) -> None:
        """
        Overview:
            Update data info, e.g. priority.
        Arguments:
            - info (:obj:`Dict[str, list]`): Info dict. Keys depend on the specific buffer type.
        """
        raise NotImplementedError

    def sample(self, batch_size: int, cur_learner_iter: int) -> list:
        """
        Overview:
            Sample data with length ``batch_size``.
        Arguments:
            - batch_size (:obj:`int`): The number of the data that will be sampled.
            - cur_learner_iter (:obj:`int`): Learner's current iteration.
        Returns:
            - sampled_data (:obj:`list`): A list of data with length ``batch_size``.
        """
        raise NotImplementedError

    def clear(self) -> None:
        """
        Overview:
            Clear all the data and reset the related variables.
        """
        raise NotImplementedError

    def count(self) -> int:
        """
        Overview:
            Count how many valid data there are in the buffer.
        Returns:
            - count (:obj:`int`): Number of valid data.
        """
        raise NotImplementedError

    def save_data(self, file_name: str):
        """
        Overview:
            Save buffer data into a file.
        Arguments:
            - file_name (:obj:`str`): File name of buffer data.
        """
        raise NotImplementedError

    def load_data(self, file_name: str):
        """
        Overview:
            Load buffer data from a file.
        Arguments:
            - file_name (:obj:`str`): File name of buffer data.
        """
        raise NotImplementedError

    def state_dict(self) -> Dict[str, Any]:
        """
        Overview:
            Provide a state dict to keep a record of current buffer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
                With the dict, one can easily reproduce the buffer.
        """
        raise NotImplementedError

    def load_state_dict(self, _state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load state dict to reproduce the buffer.
        Arguments:
            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer.
        """
        raise NotImplementedError
def create_buffer(cfg: EasyDict, *args, **kwargs) -> IBuffer:
    """
    Overview:
        Build a buffer through ``BUFFER_REGISTRY`` using the type named in ``cfg``.
        The modules listed in ``cfg.import_names`` (if any) are passed to ``import_module`` first.
    Arguments:
        - cfg (:obj:`EasyDict`): Buffer config. It is also forwarded to the registry as the first \
            build argument.
        - args, kwargs: Extra positional and keyword arguments forwarded to ``BUFFER_REGISTRY.build``.
    ArgumentsKeys:
        - necessary: `type`
        - optional: `import_names`
    Returns:
        - buffer (:obj:`IBuffer`): The object returned by ``BUFFER_REGISTRY.build``.
    """
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)

    buffer_type = cfg.type
    if buffer_type == 'naive':
        # The 'naive' type is built without ``tb_logger``; drop it silently when the caller passed one.
        # NOTE(review): presumably that buffer's constructor does not accept it - confirm.
        kwargs.pop('tb_logger', None)

    return BUFFER_REGISTRY.build(buffer_type, cfg, *args, **kwargs)
def get_buffer_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Look up, without instantiating it, the buffer class registered under ``cfg.type``.
        The modules listed in ``cfg.import_names`` (if any) are passed to ``import_module`` first.
    Arguments:
        - cfg (:obj:`EasyDict`): Buffer config.
    ArgumentsKeys:
        - necessary: `type`
        - optional: `import_names`
    Returns:
        - buffer_cls (:obj:`type`): The result of ``BUFFER_REGISTRY.get`` for ``cfg.type``.
    """
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)

    buffer_type = cfg.type
    return BUFFER_REGISTRY.get(buffer_type)