| | |
| | import importlib |
| | import os |
| | import pkgutil |
| | import warnings |
| | from collections import namedtuple |
| |
|
| | import torch |
| |
|
if torch.__version__ != 'parrots':

    def load_ext(name, funcs):
        """Import the compiled extension ``mmcv.<name>`` and verify its ops.

        Args:
            name (str): Extension module name relative to the ``mmcv``
                package (e.g. ``'_ext'``).
            funcs (Iterable[str]): Names of the functions that must be
                present on the extension module.

        Returns:
            module: The imported extension module itself.

        Raises:
            ImportError: If ``mmcv.<name>`` cannot be imported.
            AssertionError: If any name in ``funcs`` is missing from the
                module.
        """
        ext = importlib.import_module('mmcv.' + name)
        for fun in funcs:
            # Explicit raise rather than ``assert`` so the check is not
            # stripped when Python runs with ``-O``; the exception type stays
            # AssertionError for existing callers.
            if not hasattr(ext, fun):
                raise AssertionError(f'{fun} miss in module {name}')
        return ext
else:
    from parrots import extension
    from parrots.base import ParrotsException

    # Ops bound through the loaded function's ``.op`` attribute; every other
    # op is bound through ``.op_`` (see ``load_ext`` below).
    # NOTE(review): the name suggests these ops return their results while
    # the others write in place -- confirm against the parrots extension API.
    has_return_value_ops = [
        'nms',
        'softnms',
        'nms_match',
        'nms_rotated',
        'top_pool_forward',
        'top_pool_backward',
        'bottom_pool_forward',
        'bottom_pool_backward',
        'left_pool_forward',
        'left_pool_backward',
        'right_pool_forward',
        'right_pool_backward',
        'fused_bias_leakyrelu',
        'upfirdn2d',
        'ms_deform_attn_forward',
        'pixel_group',
        'contour_expand',
    ]

    def get_fake_func(name, e):
        """Build a stand-in for an op that failed to load under parrots.

        Args:
            name (str): Name of the op, used in the warning message.
            e (Exception): The error raised while loading the op.

        Returns:
            callable: A function that accepts any arguments, warns that
            ``name`` is not supported, and then re-raises ``e``.
        """

        def fake_func(*args, **kwargs):
            warnings.warn(f'{name} is not supported in parrots now')
            raise e

        return fake_func

    def load_ext(name, funcs):
        """Load ops from the parrots extension ``name``.

        Args:
            name (str): Extension name passed to ``extension.load``.
            funcs (Sequence[str]): Op names to load. They also become the
                field names of the returned namedtuple, so each must be a
                valid identifier.

        Returns:
            ExtModule: A namedtuple with one field per entry in ``funcs``,
            in the same order. An op whose load raises ``ParrotsException``
            is replaced by a stand-in from ``get_fake_func`` that warns and
            re-raises that error when called.
        """
        ExtModule = namedtuple('ExtModule', funcs)
        ext_list = []
        # Compiled libraries are looked up two directory levels above this
        # file.
        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        for fun in funcs:
            try:
                ext_fun = extension.load(fun, name, lib_dir=lib_root)
            except ParrotsException as e:
                # "No element registered" failures are not warned about
                # here (presumably the op is just absent from this build);
                # the stand-in still warns if it is ever called.
                if 'No element registered' not in e.message:
                    warnings.warn(e.message)
                ext_fun = get_fake_func(fun, e)
                ext_list.append(ext_fun)
            else:
                if fun in has_return_value_ops:
                    ext_list.append(ext_fun.op)
                else:
                    ext_list.append(ext_fun.op_)
        return ExtModule(*ext_list)
| |
|
| |
|
def check_ops_exist():
    """Return whether the compiled ``mmcv._ext`` extension can be located.

    The extension itself is only located, not imported. The lookup does
    import the parent package ``mmcv`` as a side effect.

    Returns:
        bool: ``True`` if an import spec with a loader exists for
        ``mmcv._ext``, ``False`` otherwise.

    Raises:
        ImportError: If the lookup itself fails (e.g. ``mmcv`` cannot be
            imported). This matches the behavior of the former
            ``pkgutil.find_loader`` call.
    """
    # ``pkgutil.find_loader`` is deprecated since Python 3.12 and removed in
    # 3.14; ``importlib.util.find_spec`` is its documented replacement.
    import importlib.util

    try:
        spec = importlib.util.find_spec('mmcv._ext')
    except (ImportError, AttributeError, TypeError, ValueError) as err:
        raise ImportError(
            "Error while finding loader for 'mmcv._ext' "
            f'({type(err)}: {err})') from err
    return spec is not None and spec.loader is not None
| |
|