Spaces:
Build error
Build error
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Factory methods to build models.""" | |
| # Import libraries | |
| import tensorflow as tf, tf_keras | |
| from official.core import registry | |
| from official.vision.configs import video_classification as video_classification_cfg | |
| from official.vision.modeling import video_classification_model | |
| from official.vision.modeling import backbones | |
| _REGISTERED_MODEL_CLS = {} | |
| def register_model_builder(key: str): | |
| """Decorates a builder of model class. | |
| The builder should be a Callable (a class or a function). | |
| This decorator supports registration of backbone builder as follows: | |
| ``` | |
| class MyModel(tf_keras.Model): | |
| pass | |
| @register_backbone_builder('mybackbone') | |
| def builder(input_specs, config, l2_reg): | |
| return MyModel(...) | |
| # Builds a MyModel object. | |
| my_backbone = build_backbone_3d(input_specs, config, l2_reg) | |
| ``` | |
| Args: | |
| key: the key to look up the builder. | |
| Returns: | |
| A callable for use as class decorator that registers the decorated class | |
| for creation from an instance of model class. | |
| """ | |
| return registry.register(_REGISTERED_MODEL_CLS, key) | |
| def build_model( | |
| model_type: str, | |
| input_specs: tf_keras.layers.InputSpec, | |
| model_config: video_classification_cfg.hyperparams.Config, | |
| num_classes: int, | |
| l2_regularizer: tf_keras.regularizers.Regularizer = None) -> tf_keras.Model: | |
| """Builds backbone from a config. | |
| Args: | |
| model_type: string name of model type. It should be consistent with | |
| ModelConfig.model_type. | |
| input_specs: tf_keras.layers.InputSpec. | |
| model_config: a OneOfConfig. Model config. | |
| num_classes: number of classes. | |
| l2_regularizer: tf_keras.regularizers.Regularizer instance. Default to None. | |
| Returns: | |
| tf_keras.Model instance of the backbone. | |
| """ | |
| model_builder = registry.lookup(_REGISTERED_MODEL_CLS, model_type) | |
| return model_builder(input_specs, model_config, num_classes, l2_regularizer) | |
| def build_video_classification_model( | |
| input_specs: tf_keras.layers.InputSpec, | |
| model_config: video_classification_cfg.VideoClassificationModel, | |
| num_classes: int, | |
| l2_regularizer: tf_keras.regularizers.Regularizer = None) -> tf_keras.Model: | |
| """Builds the video classification model.""" | |
| input_specs_dict = {'image': input_specs} | |
| norm_activation_config = model_config.norm_activation | |
| backbone = backbones.factory.build_backbone( | |
| input_specs=input_specs, | |
| backbone_config=model_config.backbone, | |
| norm_activation_config=norm_activation_config, | |
| l2_regularizer=l2_regularizer) | |
| model = video_classification_model.VideoClassificationModel( | |
| backbone=backbone, | |
| num_classes=num_classes, | |
| input_specs=input_specs_dict, | |
| dropout_rate=model_config.dropout_rate, | |
| aggregate_endpoints=model_config.aggregate_endpoints, | |
| kernel_regularizer=l2_regularizer, | |
| require_endpoints=model_config.require_endpoints) | |
| return model | |