Spaces:
Running
Running
| from typing import Callable, Any, List, TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from ding.data.buffer.buffer import Buffer | |
def staleness_check(buffer_: 'Buffer', max_staleness: int = float("inf")) -> Callable:
    """
    Overview:
        This middleware checks staleness before each sample operation.
        staleness = train_iter_sample_data - train_iter_data_collected, i.e. how old/off-policy the data is.
        If an item's staleness is greater(>) than max_staleness, it is deleted from the buffer at the next
        sample call, before the sample request is forwarded down the chain.
    Arguments:
        - buffer_ (:obj:`Buffer`): The buffer this middleware guards. Its ``storage`` is iterated and its \
            ``delete`` method is called with the index of every stale item.
        - max_staleness (:obj:`int`): The maximum legal span between the time of collecting and time of sampling. \
            The default is infinity, in which case no item is ever considered stale.
    Returns:
        - middleware (:obj:`Callable`): A function ``(action, next, *args, **kwargs)`` that handles the \
            ``"push"`` and ``"sample"`` actions and forwards every other action to ``next`` unchanged.
    """

    def push(next: Callable, data: Any, *args, **kwargs) -> Any:
        # Every pushed item must record the train iteration it was collected at, otherwise
        # its staleness cannot be computed at sample time. The check is raised explicitly
        # (rather than via ``assert``) so that it is still enforced under ``python -O``.
        meta = kwargs.get('meta')
        if meta is None or 'train_iter_data_collected' not in meta:
            raise AssertionError(
                "staleness_check middleware must push data with meta={'train_iter_data_collected': <iter>}"
            )
        return next(data, *args, **kwargs)

    def sample(next: Callable, train_iter_sample_data: int, *args, **kwargs) -> List[Any]:
        # First pass: refresh the staleness of every stored item and remember the stale ones.
        # Deletion is deferred to a second pass so the storage is not mutated while iterating it.
        stale_indices = []
        for item in buffer_.storage:
            meta = item.meta
            staleness = train_iter_sample_data - meta['train_iter_data_collected']
            meta['staleness'] = staleness
            if staleness > max_staleness:
                stale_indices.append(item.index)
        for index in stale_indices:
            buffer_.delete(index)
        # ``train_iter_sample_data`` is consumed here and not forwarded to the next handler.
        return next(*args, **kwargs)

    def _staleness_check(action: str, next: Callable, *args, **kwargs) -> Any:
        if action == "push":
            return push(next, *args, **kwargs)
        elif action == "sample":
            return sample(next, *args, **kwargs)
        # Any other action passes through untouched.
        return next(*args, **kwargs)

    return _staleness_check