| |
| import importlib |
| import os |
| import pkgutil |
| import warnings |
| from collections import namedtuple |
|
|
| import torch |
|
|
if torch.__version__ != 'parrots':

    def load_ext(name, funcs):
        """Import the compiled extension ``mmcv.<name>`` and verify its API.

        Args:
            name (str): Submodule name under the ``mmcv`` package
                (e.g. ``'_ext'``).
            funcs (list[str]): Attribute names that must be present on the
                imported module.

        Returns:
            module: The imported ``mmcv.<name>`` module object.

        Raises:
            AssertionError: If any entry of ``funcs`` is missing from the
                module. Raised explicitly rather than via an ``assert``
                statement so the check still runs under ``python -O``.
        """
        ext = importlib.import_module('mmcv.' + name)
        for fun in funcs:
            if not hasattr(ext, fun):
                raise AssertionError(f'{fun} miss in module {name}')
        return ext
else:
    from parrots import extension
    from parrots.base import ParrotsException

    # Ops whose parrots binding is taken from ``.op``; every other op is
    # bound through ``.op_`` (see the parrots ``load_ext`` below).
    # NOTE(review): presumably ``.op`` returns its result while ``.op_``
    # writes into its arguments -- confirm against the parrots API.
    has_return_value_ops = [
        'nms',
        'softnms',
        'nms_match',
        'nms_rotated',
        'top_pool_forward',
        'top_pool_backward',
        'bottom_pool_forward',
        'bottom_pool_backward',
        'left_pool_forward',
        'left_pool_backward',
        'right_pool_forward',
        'right_pool_backward',
        'fused_bias_leakyrelu',
        'upfirdn2d',
        'ms_deform_attn_forward',
        'pixel_group',
        'contour_expand',
    ]

    def get_fake_func(name, e):
        """Build a placeholder for an op that failed to load under parrots.

        Args:
            name (str): Op name, used in the warning message.
            e (Exception): The error raised while loading the op.

        Returns:
            callable: A function accepting any arguments that warns that
            ``name`` is unsupported and then re-raises ``e``.
        """

        def fake_func(*args, **kwargs):
            warnings.warn(f'{name} is not supported in parrots now')
            raise e

        return fake_func

    def load_ext(name, funcs):
        """Load every op in ``funcs`` from the parrots extension ``name``.

        An op whose load raises ``ParrotsException`` is replaced by a
        placeholder from ``get_fake_func``, so that failure is deferred
        until the op is actually called.

        Args:
            name (str): Extension name passed to ``extension.load``.
            funcs (list[str]): Op names to load; they also become the
                field names of the returned namedtuple.

        Returns:
            ExtModule: A namedtuple with one field per entry of ``funcs``
            (same order), each holding the loaded callable or placeholder.
        """
        ExtModule = namedtuple('ExtModule', funcs)
        # Two directory levels above this file.
        # NOTE(review): presumably the package root holding the compiled
        # library -- confirm.
        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        ext_list = []
        for fun in funcs:
            try:
                ext_fun = extension.load(fun, name, lib_dir=lib_root)
            except ParrotsException as e:
                # Assumes ParrotsException exposes ``.message``. Failures
                # mentioning 'No element registered' are not warned about
                # (presumably the op is simply not built for parrots).
                if 'No element registered' not in e.message:
                    warnings.warn(e.message)
                ext_list.append(get_fake_func(fun, e))
            else:
                if fun in has_return_value_ops:
                    ext_list.append(ext_fun.op)
                else:
                    ext_list.append(ext_fun.op_)
        return ExtModule(*ext_list)
|
|
|
|
def check_ops_exist():
    """Return whether the compiled ``mmcv._ext`` extension can be located.

    The module is only located, not imported.

    Returns:
        bool: True if an import spec with a loader is found for
        ``mmcv._ext``, otherwise False.

    Raises:
        ImportError: If the lookup itself fails (e.g. the parent ``mmcv``
            package cannot be imported), matching the contract of the
            former ``pkgutil.find_loader`` call.
    """
    # pkgutil.find_loader is deprecated since Python 3.12 and removed in
    # 3.14; importlib.util.find_spec is its documented replacement.
    import importlib.util

    try:
        spec = importlib.util.find_spec('mmcv._ext')
    except (AttributeError, TypeError, ValueError) as err:
        # find_loader reported these lookup failures as ImportError; keep
        # that so existing callers' ``except ImportError`` still applies.
        raise ImportError(
            f'Error while finding loader for mmcv._ext: {err}') from err
    return spec is not None and spec.loader is not None
|
|