| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import collections |
| from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format |
| from typing import Callable, Dict, Optional, Tuple, Type, Union, List |
|
|
|
|
def cat_collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    r"""Concatenate a batch of tensors along dimension 0.

    Zero-dimensional (scalar) tensors cannot be passed to ``torch.cat``, so a
    batch whose first element is a scalar tensor is stacked into a 1-D tensor
    instead, which is what concatenating one-element tensors would give.

    Args:
        batch: Sequence of tensors; all dimensions except the first must match.
        collate_fn_map: Unused. Accepted (and now optional) so the function follows
            the collate-function calling convention ``fn(batch, collate_fn_map=...)``,
            like ``cat_collate_list_fn``.

    Returns:
        A single tensor whose first dimension is the sum of the inputs' first
        dimensions (or ``len(batch)`` for scalar inputs).
    """
    if batch[0].dim() == 0:
        # torch.cat rejects 0-d tensors; stacking scalars is the equivalent "concat".
        return torch.stack(batch, dim=0)
    return torch.cat(batch, dim=0)
|
|
|
|
def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    r"""Flatten a batch of lists into one list.

    Args:
        batch: Iterable of iterables (one per sample).
        collate_fn_map: Unused; present to match the collate-function calling
            convention ``fn(batch, collate_fn_map=...)``.

    Returns:
        A new list holding every item of every sample, in order.
    """
    flattened = []
    for sample_items in batch:
        flattened.extend(sample_items)
    return flattened
|
|
|
|
def cat_collate_none_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    r"""Collate a batch dispatched on a ``None`` element into a single ``None``.

    The contents of ``batch`` are ignored. This is a named top-level function
    rather than a lambda so that ``cat_collate_fn_map`` (and anything carrying
    it, e.g. a ``functools.partial`` of ``cat_collate``) stays picklable for
    DataLoader worker processes started with ``spawn``.
    """
    return None


# Collate-function map for ``cat_collate``: PyTorch's default handlers, except that
# tensors are concatenated instead of stacked, lists are flattened, and ``None``
# fields stay ``None``.
cat_collate_fn_map = default_collate_fn_map.copy()
cat_collate_fn_map[torch.Tensor] = cat_collate_tensor_fn
# Keyed on the builtin ``list`` (not ``typing.List``) so list samples hit the
# exact-type lookup rather than relying on isinstance() against a typing alias.
cat_collate_fn_map[list] = cat_collate_list_fn
cat_collate_fn_map[type(None)] = cat_collate_none_fn
|
|
|
|
def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    r"""Custom collate function that concatenates stuff instead of stacking them, and handles NoneTypes.

    Dispatch is on the type of ``batch[0]``. The function first looks in
    ``collate_fn_map``, which defaults to the module-level
    ``cat_collate_fn_map``. That map concatenates tensors, flattens lists and
    maps ``None`` to ``None``. Otherwise mappings, namedtuples and other
    sequences are collated field-by-field recursively.

    Args:
        batch: Non-empty sequence of samples; all are presumed to share the
            structure of ``batch[0]``.
        collate_fn_map: Mapping from a type (or tuple of types) to a collate
            function called as ``fn(batch, collate_fn_map=...)``. When ``None``,
            ``cat_collate_fn_map`` is used, so ``collate_fn=cat_collate`` works
            directly in a DataLoader.

    Returns:
        The collated batch.
        - Mappings keep their type when it can be rebuilt from a plain dict; otherwise a ``dict`` is returned.
        - Namedtuples keep their type.
        - Plain tuples become lists.
        - Other sequences keep their type when it can be built from a list; otherwise a ``list`` is returned.

    Raises:
        IndexError: If ``batch`` is empty.
        RuntimeError: If sequence samples in the batch have different lengths.
        TypeError: If ``batch[0]`` has a type no handler or branch supports.
    """
    if collate_fn_map is None:
        # Without a map no leaf (tensor, number, str, None) can be collated at all.
        collate_fn_map = cat_collate_fn_map

    elem = batch[0]
    elem_type = type(elem)

    # Exact-type lookup first, then an isinstance scan (keys may be base classes or tuples of types).
    if elem_type in collate_fn_map:
        return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
    for collate_type in collate_fn_map:
        if isinstance(elem, collate_type):
            return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, collections.abc.Mapping):
        # Collate once; only the container construction is allowed to fail and fall back.
        collated = {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
        try:
            return elem_type(collated)
        except TypeError:
            # The mapping type cannot be built from a single dict (e.g. defaultdict).
            return collated

    if isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))

    if isinstance(elem, collections.abc.Sequence):
        # zip() would silently drop trailing fields of longer samples.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError('each element in list of batch should be of equal size')

        collated = [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)]
        if isinstance(elem, tuple):
            # Plain tuples are returned as lists.
            return collated
        try:
            return elem_type(collated)
        except TypeError:
            # The sequence type cannot be built from a list (e.g. range).
            return collated

    raise TypeError(default_collate_err_msg_format.format(elem_type))
|
|