File size: 467 Bytes
25986db | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | import torch.nn as nn
# Use DistributedDataParallel (DDP) rather than DataParallel (DP) for multi-GPU training.
def is_multi_gpu(net):
    """Tell whether ``net`` has been wrapped for distributed multi-GPU training.

    Args:
        net: The object to inspect (normally an ``nn.Module``).

    Returns:
        ``True`` when ``net`` is a ``MultiGPU`` or a plain
        ``DistributedDataParallel`` instance, otherwise ``False``.
    """
    ddp_types = (MultiGPU, nn.parallel.distributed.DistributedDataParallel)
    return isinstance(net, ddp_types)
class MultiGPU(nn.parallel.distributed.DistributedDataParallel):
    """``DistributedDataParallel`` that forwards unknown attributes to the wrapped module.

    Lets callers write ``wrapped.some_attr`` instead of
    ``wrapped.module.some_attr`` for attributes defined on the inner model.
    """

    def __getattr__(self, item):
        """Resolve ``item`` on the DDP wrapper first, then on ``self.module``.

        Args:
            item: Name of the attribute being looked up.

        Returns:
            The attribute found on the wrapper, or else on ``self.module``.

        Raises:
            AttributeError: if ``module`` itself is not registered on the
                wrapper. Otherwise it propagates whatever
                ``getattr(self.module, item)`` raises (normally AttributeError).
        """
        try:
            # Parent lookup covers attributes registered on the wrapper itself,
            # including its ``module`` submodule.
            return super().__getattr__(item)
        except AttributeError:
            # Only "not found" falls through; any other error propagates.
            # Without this guard a missing ``module`` would re-enter
            # __getattr__('module') endlessly and end in RecursionError.
            if item == "module":
                raise
        return getattr(self.module, item)
|