| |
|
| |
|
| |
|
| |
|
| |
|
| | from importlib import import_module |
| | import pkgutil |
| | import sys |
| | from . import layers |
| |
|
| |
|
def import_recursive(package):
    """
    Import every module and subpackage found beneath ``package``.

    Entries are discovered with ``pkgutil.iter_modules`` on the package's
    ``__path__`` and imported by their fully-qualified dotted name. A
    subpackage is descended into right after it is imported, so the overall
    import order is depth-first in discovery order. Nothing is returned; the
    effect is that the modules end up imported.
    """
    parent_name = package.__name__
    for entry in pkgutil.iter_modules(package.__path__):
        _finder, child_name, child_is_package = entry
        child = import_module(parent_name + "." + child_name)
        if not child_is_package:
            continue
        # Recurse only into packages: plain modules have no __path__ to walk.
        import_recursive(child)
| |
|
| |
|
def find_subclasses_recursively(base_cls, sub_cls):
    """
    Add every direct and indirect subclass of ``base_cls`` to ``sub_cls``.

    Args:
        base_cls: class whose subclass tree is walked. ``base_cls`` itself is
            not added.
        sub_cls: set that is updated in place with every subclass found. Any
            classes it already contains are kept.

    Each class is expanded at most once. A class reachable through several
    parents (diamond / multiple inheritance) therefore has its subtree walked
    only once, instead of once per parent. An explicit stack replaces
    recursion, so deep hierarchies cannot hit the interpreter's recursion
    limit.
    """
    # Track expansion separately from ``sub_cls`` so that a caller-supplied,
    # pre-populated set does not stop us from descending into its members.
    expanded = set()
    pending = [base_cls]
    while pending:
        parent = pending.pop()
        children = parent.__subclasses__()
        sub_cls.update(children)
        for child in children:
            if child not in expanded:
                expanded.add(child)
                pending.append(child)
| |
|
| |
|
# Import every module under this package first, so that all ModelLayer
# subclasses defined in them exist before we look for them.
import_recursive(sys.modules[__name__])

# Collect all direct and indirect subclasses of layers.ModelLayer.
model_layer_subcls = set()
find_subclasses_recursively(layers.ModelLayer, model_layer_subcls)

# Register each discovered class under its class name. The set is not
# modified while iterating, so no defensive copy is needed.
for cls in model_layer_subcls:
    layers.register_layer(cls.__name__, cls)
| |
|