|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Backbone registers and factory method.
|
|
|
| One can register a new backbone model by the following two steps:
|
|
|
| 1. Import the factory and register the builder in the backbone file.
|
| 2. Add an import of the backbone class to backbones/__init__.py.
|
|
|
| ```
|
| # my_backbone.py
|
|
|
| from modeling.backbones import factory
|
|
|
| class MyBackbone():
|
| ...
|
|
|
| @factory.register_backbone_builder('my_backbone')
|
| def build_my_backbone():
|
| return MyBackbone()
|
|
|
| # backbones/__init__.py adds import
|
| from modeling.backbones.my_backbone import MyBackbone
|
| ```
|
|
|
| If one wants the MyBackbone class to be used only by certain binaries,
|
| then don't import the backbone module in backbones/__init__.py, but import it
|
| in the place that uses it.
|
|
|
|
|
| """
|
| from typing import Sequence, Union
|
|
|
|
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.core import registry
|
| from official.modeling import hyperparams
|
|
|
|
|
# Module-level registry dict; handed to `registry.register` by
# `register_backbone_builder` and to `registry.lookup` by `build_backbone`.
| _REGISTERED_BACKBONE_CLS = {}
|
|
|
|
|
def register_backbone_builder(key: str):
    """Returns a decorator that registers a backbone builder under `key`.

    The decorated object should be a callable (a class or a function) that
    builds a backbone. Registration is delegated to `registry.register` with
    this module's backbone registry, so that `build_backbone` can later look
    the builder up via `backbone_config.type`.

    Example:

    ```
    class MyBackbone(tf_keras.Model):
      pass

    @register_backbone_builder('mybackbone')
    def builder(input_specs, backbone_config, norm_activation_config,
                l2_regularizer=None, **kwargs):
      return MyBackbone(...)

    # Builds a MyBackbone object when `backbone_config.type` is 'mybackbone'.
    my_backbone = build_backbone(
        input_specs, backbone_config, norm_activation_config, l2_regularizer)
    ```

    Args:
      key: A `str` key under which the builder is registered and looked up.

    Returns:
      The object returned by `registry.register`, intended to be applied as a
      decorator to the backbone builder.
    """
    decorator = registry.register(_REGISTERED_BACKBONE_CLS, key)
    return decorator
|
|
|
|
|
def build_backbone(
    input_specs: Union[tf_keras.layers.InputSpec,
                       Sequence[tf_keras.layers.InputSpec]],
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    # Explicit "or None": PEP 484 discourages the implicit-Optional form
    # `X = None`. `Union` is used because the file already imports it.
    l2_regularizer: Union[tf_keras.regularizers.Regularizer, None] = None,
    **kwargs) -> tf_keras.Model:
    """Builds a backbone from a config.

    Looks up the builder registered under `backbone_config.type` in this
    module's backbone registry and calls it with the arguments below, all
    passed by keyword.

    Args:
      input_specs: A (sequence of) `tf_keras.layers.InputSpec` of input.
      backbone_config: A `OneOfConfig` of backbone config. Its `type` attribute
        selects the registered builder.
      norm_activation_config: A config for normalization/activation layer.
      l2_regularizer: A `tf_keras.regularizers.Regularizer` object, or None.
        Defaults to None.
      **kwargs: Additional keyword args to be passed to the backbone builder.

    Returns:
      The value returned by the registered builder, expected to be a
      `tf_keras.Model` instance of the backbone.
    """
    builder = registry.lookup(_REGISTERED_BACKBONE_CLS, backbone_config.type)

    # Builders are invoked purely by keyword, so a registered builder must
    # accept these parameter names (plus any extra `kwargs` supplied here).
    return builder(
        input_specs=input_specs,
        backbone_config=backbone_config,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer,
        **kwargs)
|
|
|