| | |
| | from itertools import chain |
| |
|
| | from torch.nn.parallel import DataParallel |
| |
|
| | from .scatter_gather import scatter_kwargs |
| |
|
| |
|
class MMDataParallel(DataParallel):
    """The DataParallel module that supports DataContainer.

    MMDataParallel has two main differences with PyTorch DataParallel:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data during both GPU and CPU inference.
    - It implements two more APIs ``train_step()`` and ``val_step()``.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        device_ids (list[int]): Device IDs of modules to be scattered to.
            Defaults to None when GPU is not available.
        output_device (str | int): Device ID for output. Defaults to None.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        # Stored explicitly because ``scatter()`` reads ``self.dim``.
        # NOTE(review): presumably the parent constructor can return early
        # (no GPU available) before setting ``dim`` -- confirm before
        # removing this assignment.
        self.dim = dim

    def forward(self, *inputs, **kwargs):
        """Override the original forward function.

        The main difference lies in the CPU inference where the data in
        :class:`DataContainers` will still be gathered.
        """
        if not self.device_ids:
            # No device ids: run on the wrapped module directly, but still
            # pass the inputs through ``scatter`` (with the placeholder
            # device id -1) so they are processed the same way as on GPU.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module(*inputs[0], **kwargs[0])
        return super().forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter ``inputs`` and ``kwargs`` to ``device_ids``.

        Delegates to :func:`scatter_kwargs`, splitting along ``self.dim``.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        """Scatter the inputs and call ``self.module.train_step``.

        Returns:
            Whatever ``self.module.train_step`` returns.

        Raises:
            AssertionError: If more than one device id is configured.
            RuntimeError: If a parameter or buffer of the module is not on
                ``self.src_device_obj``.
        """
        return self._run_step('train_step', inputs, kwargs)

    def val_step(self, *inputs, **kwargs):
        """Scatter the inputs and call ``self.module.val_step``.

        Returns:
            Whatever ``self.module.val_step`` returns.

        Raises:
            AssertionError: If more than one device id is configured.
            RuntimeError: If a parameter or buffer of the module is not on
                ``self.src_device_obj``.
        """
        return self._run_step('val_step', inputs, kwargs)

    def _run_step(self, step_name, inputs, kwargs):
        """Shared implementation of ``train_step()`` and ``val_step()``.

        Args:
            step_name (str): Name of the method to call on ``self.module``.
            inputs (tuple): Positional arguments of the step call.
            kwargs (dict): Keyword arguments of the step call.
        """
        if not self.device_ids:
            # No device ids: same handling as in ``forward`` -- scatter with
            # the placeholder device id -1, then call the module directly.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return getattr(self.module, step_name)(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training, if you need to'
             ' train with multiple GPUs, please use MMDistributedDataParallel'
             ' instead.')

        # Every parameter and buffer has to live on the source device before
        # the step runs; fail early with an explicit message otherwise.
        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return getattr(self.module, step_name)(*inputs[0], **kwargs[0])
| |
|