Spaces:
Running
Running
| import random | |
| from typing import Callable, List | |
| from ding.data.buffer.buffer import BufferedData | |
def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable:
    """
    Overview:
        The middleware is designed to process the data in each group after sampling from the buffer.
        For the ``sample`` action it calls the rest of the chain, then shrinks every returned group
        to at most ``size_in_group`` records. Any other action is forwarded to the chain unchanged.
    Arguments:
        - size_in_group (:obj:`int`): Sample size in each group.
        - ordered_in_group (:obj:`bool`): Whether to keep the original order of records, default is true. \
            When true, a contiguous slice of the group is returned; when false, records are drawn at \
            random without replacement.
        - max_use_in_group (:obj:`bool`): Whether to use as much data in each group as possible, default \
            is true. Only affects ordered sampling: when true, the slice start is limited so that the \
            slice is not cut short by the end of the group; when false, the start may be any index, so \
            the slice may contain fewer than ``size_in_group`` records.
    Returns:
        - middleware (:obj:`Callable`): A callable with signature ``(action, chain, *args, **kwargs)``.
    Raises:
        - Exception: If the ``sample`` action is invoked without a truthy ``groupby`` keyword argument.

    .. note::
        Groups holding fewer than ``size_in_group`` records are returned with all the records they
        have (ordered: as a slice, unordered: shuffled) instead of raising, and empty groups stay empty.
    """

    def sample(chain: Callable, *args, **kwargs) -> List[List[BufferedData]]:
        if not kwargs.get("groupby"):
            raise Exception("Group sample must be used when the `groupby` parameter is specified.")
        sampled_data = chain(*args, **kwargs)
        for i, grouped_data in enumerate(sampled_data):
            group_len = len(grouped_data)
            if group_len == 0:
                # Nothing to pick from; an empty range/population would make `random` raise.
                continue
            if ordered_in_group:
                if max_use_in_group:
                    # Restrict the start so that a full window fits whenever the group is long enough.
                    end = max(0, group_len - size_in_group) + 1
                else:
                    end = group_len
                start_idx = random.randrange(end)
                sampled_data[i] = grouped_data[start_idx:start_idx + size_in_group]
            else:
                # Clamp k: random.sample raises ValueError when k exceeds the population size.
                sampled_data[i] = random.sample(grouped_data, k=min(size_in_group, group_len))
        return sampled_data

    def _group_sample(action: str, chain: Callable, *args, **kwargs):
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _group_sample