| | |
| | from functools import partial |
| |
|
| | import torch |
| |
|
# Backend version string. This module compares it against 'parrots' to pick
# parrots-specific code paths; anything else is treated as PyTorch.
TORCH_VERSION = torch.__version__


def is_rocm_pytorch() -> bool:
    """Return whether the active backend is a ROCm (HIP) build of PyTorch.

    Always ``False`` under parrots. Under PyTorch, ``True`` only when
    ``torch.version.hip`` is set *and* ``torch.utils.cpp_extension`` exposes
    a non-``None`` ``ROCM_HOME``. If ``ROCM_HOME`` cannot be imported from
    ``torch.utils.cpp_extension``, the build is reported as non-ROCm.

    Returns:
        bool: ``True`` for a usable ROCm PyTorch build, else ``False``.
    """
    if TORCH_VERSION == 'parrots':
        return False
    try:
        from torch.utils.cpp_extension import ROCM_HOME
    except ImportError:
        # cpp_extension without ROCM_HOME: deliberately best-effort, not ROCm.
        return False
    # Both operands are already bools; no `True if ... else False` needed.
    return (torch.version.hip is not None) and (ROCM_HOME is not None)
| |
|
| |
|
def _get_cuda_home():
    """Return the GPU toolkit root used when building native extensions.

    Under parrots, return parrots' own ``CUDA_HOME``. Under a ROCm build of
    PyTorch (see ``is_rocm_pytorch``), return PyTorch's ``ROCM_HOME``.
    Otherwise, return PyTorch's ``CUDA_HOME``.
    """
    if TORCH_VERSION == 'parrots':
        from parrots.utils.build_extension import CUDA_HOME as toolkit_home
    elif is_rocm_pytorch():
        from torch.utils.cpp_extension import ROCM_HOME as toolkit_home
    else:
        from torch.utils.cpp_extension import CUDA_HOME as toolkit_home
    return toolkit_home
| |
|
| |
|
def get_build_config():
    """Return build-configuration information for the active backend.

    Delegates to ``torch.__config__.show()`` under PyTorch and to
    ``parrots.config.get_build_info()`` under parrots.
    """
    if TORCH_VERSION != 'parrots':
        return torch.__config__.show()
    from parrots.config import get_build_info
    return get_build_info()
| |
|
| |
|
def _get_conv():
    """Return the backend's private convolution base classes.

    Returns:
        tuple: ``(_ConvNd, _ConvTransposeMixin)``, taken from
        ``torch.nn.modules.conv`` or ``parrots.nn.modules.conv``.
    """
    if TORCH_VERSION != 'parrots':
        from torch.nn.modules.conv import (
            _ConvNd as conv_base, _ConvTransposeMixin as transpose_mixin)
    else:
        from parrots.nn.modules.conv import (
            _ConvNd as conv_base, _ConvTransposeMixin as transpose_mixin)
    return conv_base, transpose_mixin
| |
|
| |
|
def _get_dataloader():
    """Return ``(DataLoader, PoolDataLoader)`` for the active backend.

    Only the parrots path imports a distinct ``PoolDataLoader``. Under
    PyTorch, the plain ``DataLoader`` class fills both positions.
    """
    if TORCH_VERSION != 'parrots':
        from torch.utils.data import DataLoader as loader_cls
        return loader_cls, loader_cls
    # The parrots path still imports through the `torch.utils.data` namespace.
    from torch.utils.data import DataLoader as loader_cls
    from torch.utils.data import PoolDataLoader as pool_loader_cls
    return loader_cls, pool_loader_cls
| |
|
| |
|
def _get_extension():
    """Return ``(BuildExtension, CppExtension, CUDAExtension)`` for the backend.

    PyTorch provides all three directly in ``torch.utils.cpp_extension``.
    parrots provides an ``Extension`` callable that takes a ``cuda`` keyword,
    so the C++-only and CUDA variants are derived from it with
    ``functools.partial``.
    """
    if TORCH_VERSION != 'parrots':
        from torch.utils.cpp_extension import (BuildExtension, CppExtension,
                                               CUDAExtension)
        return BuildExtension, CppExtension, CUDAExtension
    from parrots.utils.build_extension import BuildExtension, Extension
    cpp_only_ext = partial(Extension, cuda=False)
    cuda_ext = partial(Extension, cuda=True)
    return BuildExtension, cpp_only_ext, cuda_ext
| |
|
| |
|
def _get_pool():
    """Return the backend's private pooling base classes.

    parrots keeps them in ``parrots.nn.modules.pool``; PyTorch keeps them in
    ``torch.nn.modules.pooling``.

    Returns:
        tuple: ``(_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd,
        _MaxPoolNd)``.
    """
    if TORCH_VERSION != 'parrots':
        from torch.nn.modules.pooling import (
            _AdaptiveAvgPoolNd as adaptive_avg_pool,
            _AdaptiveMaxPoolNd as adaptive_max_pool,
            _AvgPoolNd as avg_pool,
            _MaxPoolNd as max_pool)
    else:
        from parrots.nn.modules.pool import (
            _AdaptiveAvgPoolNd as adaptive_avg_pool,
            _AdaptiveMaxPoolNd as adaptive_max_pool,
            _AvgPoolNd as avg_pool,
            _MaxPoolNd as max_pool)
    return adaptive_avg_pool, adaptive_max_pool, avg_pool, max_pool
| |
|
| |
|
def _get_norm():
    """Return ``(_BatchNorm, _InstanceNorm, SyncBatchNorm_)`` for the backend.

    ``SyncBatchNorm_`` is the raw synchronized batch-norm class. It is
    ``torch.nn.SyncBatchNorm2d`` under parrots and ``torch.nn.SyncBatchNorm``
    otherwise, and the module's ``SyncBatchNorm`` class subclasses it.
    """
    if TORCH_VERSION != 'parrots':
        from torch.nn.modules.instancenorm import _InstanceNorm as instance_norm
        from torch.nn.modules.batchnorm import _BatchNorm as batch_norm
        sync_batch_norm = torch.nn.SyncBatchNorm
    else:
        from parrots.nn.modules.batchnorm import (
            _BatchNorm as batch_norm, _InstanceNorm as instance_norm)
        sync_batch_norm = torch.nn.SyncBatchNorm2d
    return batch_norm, instance_norm, sync_batch_norm
| |
|
| |
|
# Bind the backend-specific (parrots vs. PyTorch) classes to module-level
# names once, at import time. Each _get_* helper performs the actual imports.
_ConvNd, _ConvTransposeMixin = _get_conv()
# Under PyTorch, PoolDataLoader is the same object as DataLoader.
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
# SyncBatchNorm_ is the raw backend class; SyncBatchNorm below subclasses it.
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
| |
|
| |
|
class SyncBatchNorm(SyncBatchNorm_):
    """Synchronized batch norm that works on both parrots and PyTorch.

    Subclasses the backend's raw sync-BN class (``SyncBatchNorm_``) and
    overrides only the input-dimension check.
    """

    def _check_input_dim(self, input):
        """Validate the rank of ``input``.

        Under parrots, raise ``ValueError`` when ``input`` has fewer than two
        dimensions. Under PyTorch, defer to the parent class's check.
        """
        if TORCH_VERSION != 'parrots':
            super()._check_input_dim(input)
            return
        rank = input.dim()
        if rank < 2:
            raise ValueError(
                f'expected at least 2D input (got {rank}D input)')
| |
|