| | |
| | import torch |
| | from torch.nn.parallel._functions import Scatter as OrigScatter |
| |
|
| | from ._functions import Scatter |
| | from .data_container import DataContainer |
| |
|
| |
|
def scatter(inputs, target_gpus, dim=0):
    """Scatter inputs to target gpus.

    The only difference from original :func:`scatter` is to add support for
    :type:`~mmcv.parallel.DataContainer`.

    Args:
        inputs: A ``torch.Tensor``, a ``DataContainer``, or a (possibly
            nested) tuple / list / dict of such objects. Any other leaf
            object is replicated once per entry of ``target_gpus``.
        target_gpus: Target device ids. When it compares equal to ``[-1]``,
            tensors are handed to the local ``Scatter.forward`` instead of
            PyTorch's ``Scatter`` (presumably the CPU-only path).
        dim (int): Split dimension, forwarded only to PyTorch's
            ``Scatter.apply``; the ``Scatter.forward`` paths do not use it.

    Returns:
        The scattered structure, transposed so that the outer sequence is
        indexed by device.
    """

    def _split(item):
        # Tensor leaves: PyTorch's Scatter for real device ids, the local
        # Scatter implementation when target_gpus is [-1].
        if isinstance(item, torch.Tensor):
            if target_gpus != [-1]:
                return OrigScatter.apply(target_gpus, None, dim, item)
            return Scatter.forward(target_gpus, item)

        # DataContainer leaves: cpu_only payloads are returned untouched,
        # everything else goes through the local Scatter.
        if isinstance(item, DataContainer):
            if item.cpu_only:
                return item.data
            return Scatter.forward(target_gpus, item.data)

        # Non-empty containers: split every child, then transpose with zip so
        # that each output element holds one device's share of all children.
        if isinstance(item, tuple) and len(item) > 0:
            children = [_split(child) for child in item]
            return list(zip(*children))
        if isinstance(item, list) and len(item) > 0:
            children = [_split(child) for child in item]
            return [list(share) for share in zip(*children)]
        if isinstance(item, dict) and len(item) > 0:
            rebuild = type(item)
            # Each (key, value) pair is split as a tuple, which yields per-device
            # (key, value_share) pairs that rebuild one mapping per device.
            pairs = [_split(pair) for pair in item.items()]
            return [rebuild(share) for share in zip(*pairs)]

        # Anything else (including empty containers) is replicated per target.
        return [item for _ in target_gpus]

    # _split is recursive, so it refers to itself through a closure cell,
    # which creates a reference cycle (function -> closure -> cell ->
    # function). Rebinding the local name to None after the call clears the
    # cell and breaks the cycle.
    try:
        return _split(inputs)
    finally:
        _split = None
| |
|
| |
|
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    """Scatter with support for kwargs dictionary.

    Positional ``inputs`` and keyword ``kwargs`` are each passed through
    :func:`scatter` (an empty or falsy value becomes an empty list). The
    shorter of the two results is then padded, with ``()`` for positional
    entries or ``{}`` for keyword entries, until both have the same length.

    Returns:
        tuple: ``(inputs, kwargs)``, each converted to a tuple.
    """
    scattered_args = scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []

    # Positive gap: kwargs produced more entries; negative: inputs did.
    gap = len(scattered_kwargs) - len(scattered_args)
    if gap > 0:
        scattered_args.extend(() for _ in range(gap))
    elif gap < 0:
        scattered_kwargs.extend({} for _ in range(-gap))

    return tuple(scattered_args), tuple(scattered_kwargs)
| |
|