Spaces:
Running
Running
| import random | |
| from typing import Callable, Union, List | |
| from ding.data.buffer import BufferedData | |
| from ding.utils import fastcopy | |
| def padding(policy="random"): | |
| """ | |
| Overview: | |
| Fill the nested buffer list to the same size as the largest list. | |
| The default policy `random` will randomly select data from each group | |
| and fill it into the current group list. | |
| Arguments: | |
| - policy (:obj:`str`): Padding policy, supports `random`, `none`. | |
| """ | |
| def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]: | |
| sampled_data = chain(*args, **kwargs) | |
| if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData): | |
| return sampled_data | |
| max_len = len(max(sampled_data, key=len)) | |
| for i, grouped_data in enumerate(sampled_data): | |
| group_len = len(grouped_data) | |
| if group_len == max_len: | |
| continue | |
| for _ in range(max_len - group_len): | |
| if policy == "random": | |
| sampled_data[i].append(fastcopy.copy(random.choice(grouped_data))) | |
| elif policy == "none": | |
| sampled_data[i].append(BufferedData(data=None, index=None, meta=None)) | |
| return sampled_data | |
| def _padding(action: str, chain: Callable, *args, **kwargs): | |
| if action == "sample": | |
| return sample(chain, *args, **kwargs) | |
| return chain(*args, **kwargs) | |
| return _padding | |