| import torch.nn as nn | |
# We use DistributedDataParallel (DDP) rather than DataParallel (DP) for multi-GPU training.
def is_multi_gpu(net):
    """Return True when *net* is wrapped for distributed multi-GPU training.

    A network counts as multi-GPU if it is an instance of ``MultiGPU`` or of
    ``nn.parallel.distributed.DistributedDataParallel``.
    """
    parallel_wrappers = (
        MultiGPU,
        nn.parallel.distributed.DistributedDataParallel,
    )
    return isinstance(net, parallel_wrappers)
class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
    """DistributedDataParallel wrapper that exposes the wrapped module's attributes.

    Attribute lookups that cannot be resolved on the wrapper itself are
    forwarded to ``self.module``, so callers can reach the inner network's
    attributes and methods without writing ``net.module.<name>`` explicitly.
    """

    def __getattr__(self, item):
        """Resolve *item* on the wrapper first, then fall back to ``self.module``.

        Args:
            item: Name of the attribute being looked up.

        Returns:
            The attribute found on the wrapper or, failing that, on
            ``self.module``.

        Raises:
            AttributeError: If neither the wrapper nor ``self.module`` has
                the attribute, or if ``module`` itself is not available yet.
        """
        try:
            return super().__getattr__(item)
        except AttributeError:
            # Only a genuinely missing attribute triggers the fallback; any
            # other exception propagates instead of being silently swallowed.
            if item == "module":
                # ``self.module`` is itself unresolved; forwarding would
                # re-enter this method for "module" forever and end in a
                # RecursionError, so surface the AttributeError instead.
                raise
        return getattr(self.module, item)