|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Factory methods to build models."""
|
|
|
| from typing import Optional
|
|
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.vision.configs import image_classification as classification_cfg
|
| from official.vision.configs import maskrcnn as maskrcnn_cfg
|
| from official.vision.configs import retinanet as retinanet_cfg
|
| from official.vision.configs import semantic_segmentation as segmentation_cfg
|
| from official.vision.modeling import backbones
|
| from official.vision.modeling import classification_model
|
| from official.vision.modeling import decoders
|
| from official.vision.modeling import maskrcnn_model
|
| from official.vision.modeling import retinanet_model
|
| from official.vision.modeling import segmentation_model
|
| from official.vision.modeling.heads import dense_prediction_heads
|
| from official.vision.modeling.heads import instance_heads
|
| from official.vision.modeling.heads import segmentation_heads
|
| from official.vision.modeling.layers import detection_generator
|
| from official.vision.modeling.layers import mask_sampler
|
| from official.vision.modeling.layers import roi_aligner
|
| from official.vision.modeling.layers import roi_generator
|
| from official.vision.modeling.layers import roi_sampler
|
|
|
|
|
def build_classification_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: classification_cfg.ImageClassificationModel,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    skip_logits_layer: bool = False,
    backbone: Optional[tf_keras.Model] = None) -> tf_keras.Model:
  """Builds the classification model.

  Args:
    input_specs: Spec of the model input. Forwarded both to the backbone
      builder (when a backbone is built here) and to `ClassificationModel`.
    model_config: The `ImageClassificationModel` config. Its `backbone`,
      `norm_activation`, `num_classes`, `dropout_rate`, `kernel_initializer`
      and `add_head_batch_norm` fields are read.
    l2_regularizer: Optional kernel regularizer, shared by the backbone built
      here and the classification head.
    skip_logits_layer: Forwarded unchanged to `ClassificationModel`.
    backbone: Optional pre-built backbone. When falsy, a backbone is built
      from `model_config.backbone`.

  Returns:
    A `classification_model.ClassificationModel` instance.
  """
  norm_cfg = model_config.norm_activation

  # Only build a backbone from config when the caller did not supply one.
  if not backbone:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)

  return classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_initializer=model_config.kernel_initializer,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      skip_logits_layer=skip_logits_layer)
|
|
|
|
|
def build_maskrcnn(input_specs: tf_keras.layers.InputSpec,
                   model_config: maskrcnn_cfg.MaskRCNN,
                   l2_regularizer: Optional[
                       tf_keras.regularizers.Regularizer] = None,
                   backbone: Optional[tf_keras.Model] = None,
                   decoder: Optional[tf_keras.Model] = None) -> tf_keras.Model:
  """Builds Mask R-CNN model.

  When `model_config.roi_sampler.cascade_iou_thresholds` is non-empty, the
  detection head and ROI sampler are built as cascades: one extra stage per
  threshold, in addition to the first stage.

  Args:
    input_specs: Spec of the model input. `input_specs.shape[1:]` is used to
      create a symbolic `tf_keras.Input` that is run through the backbone
      (and the decoder/RPN head, when a decoder exists).
    model_config: The `MaskRCNN` config.
    l2_regularizer: Optional kernel regularizer passed to every sub-network
      built here.
    backbone: Optional pre-built backbone. When falsy, it is built from
      `model_config.backbone`.
    decoder: Optional pre-built decoder. When falsy, one is requested from
      `decoders.factory.build_decoder`. The result may still be falsy
      (presumably when the config asks for no decoder — verify against the
      decoder factory), in which case the RPN head is not called here.

  Returns:
    A `maskrcnn_model.MaskRCNNModel` instance.
  """
  norm_activation_config = model_config.norm_activation
  if not backbone:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
  # Run the backbone on a symbolic input whether or not it was supplied by
  # the caller; the features feed the decoder/RPN head call below.
  backbone_features = backbone(tf_keras.Input(input_specs.shape[1:]))

  if not decoder:
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  rpn_head_config = model_config.rpn_head
  roi_generator_config = model_config.roi_generator
  roi_sampler_config = model_config.roi_sampler
  roi_aligner_config = model_config.roi_aligner
  detection_head_config = model_config.detection_head
  generator_config = model_config.detection_generator
  num_anchors_per_location = (
      len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)

  rpn_head = dense_prediction_heads.RPNHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_anchors_per_location=num_anchors_per_location,
      num_convs=rpn_head_config.num_convs,
      num_filters=rpn_head_config.num_filters,
      use_separable_conv=rpn_head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  def _build_detection_head(name: str) -> tf_keras.layers.Layer:
    """Creates one detection-head stage; all stages share the same config."""
    return instance_heads.DetectionHead(
        num_classes=model_config.num_classes,
        num_convs=detection_head_config.num_convs,
        num_filters=detection_head_config.num_filters,
        use_separable_conv=detection_head_config.use_separable_conv,
        num_fcs=detection_head_config.num_fcs,
        fc_dims=detection_head_config.fc_dims,
        class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer,
        name=name)

  detection_head = _build_detection_head('detection_head')

  if decoder:
    decoder_features = decoder(backbone_features)
    # Call the RPN head once on symbolic features (presumably so its
    # variables are created before the model is assembled).
    rpn_head(decoder_features)

  if roi_sampler_config.cascade_iou_thresholds:
    # Cascade: the first-stage head plus one extra head per IoU threshold,
    # named 'detection_head_1', 'detection_head_2', ...
    detection_head = [detection_head] + [
        _build_detection_head('detection_head_{}'.format(stage + 1))
        for stage in range(len(roi_sampler_config.cascade_iou_thresholds))
    ]

  roi_generator_obj = roi_generator.MultilevelROIGenerator(
      pre_nms_top_k=roi_generator_config.pre_nms_top_k,
      pre_nms_score_threshold=roi_generator_config.pre_nms_score_threshold,
      pre_nms_min_size_threshold=(
          roi_generator_config.pre_nms_min_size_threshold),
      nms_iou_threshold=roi_generator_config.nms_iou_threshold,
      num_proposals=roi_generator_config.num_proposals,
      test_pre_nms_top_k=roi_generator_config.test_pre_nms_top_k,
      test_pre_nms_score_threshold=(
          roi_generator_config.test_pre_nms_score_threshold),
      test_pre_nms_min_size_threshold=(
          roi_generator_config.test_pre_nms_min_size_threshold),
      test_nms_iou_threshold=roi_generator_config.test_nms_iou_threshold,
      test_num_proposals=roi_generator_config.test_num_proposals,
      use_batched_nms=roi_generator_config.use_batched_nms)

  # The first sampler uses the full config. Cascade samplers (one per IoU
  # threshold) do not mix in ground-truth boxes, use the threshold for both
  # the foreground and background-high bounds, and skip subsampling.
  roi_sampler_cascade = [
      roi_sampler.ROISampler(
          mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
          num_sampled_rois=roi_sampler_config.num_sampled_rois,
          foreground_fraction=roi_sampler_config.foreground_fraction,
          foreground_iou_threshold=roi_sampler_config.foreground_iou_threshold,
          background_iou_high_threshold=(
              roi_sampler_config.background_iou_high_threshold),
          background_iou_low_threshold=(
              roi_sampler_config.background_iou_low_threshold))
  ]
  if roi_sampler_config.cascade_iou_thresholds:
    roi_sampler_cascade.extend(
        roi_sampler.ROISampler(
            mix_gt_boxes=False,
            num_sampled_rois=roi_sampler_config.num_sampled_rois,
            foreground_iou_threshold=iou,
            background_iou_high_threshold=iou,
            background_iou_low_threshold=0.0,
            skip_subsampling=True)
        for iou in roi_sampler_config.cascade_iou_thresholds)

  roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=roi_aligner_config.crop_size,
      sample_offset=roi_aligner_config.sample_offset)

  detection_generator_obj = detection_generator.DetectionGenerator(
      apply_nms=generator_config.apply_nms,
      pre_nms_top_k=generator_config.pre_nms_top_k,
      pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
      nms_iou_threshold=generator_config.nms_iou_threshold,
      max_num_detections=generator_config.max_num_detections,
      nms_version=generator_config.nms_version,
      use_cpu_nms=generator_config.use_cpu_nms,
      soft_nms_sigma=generator_config.soft_nms_sigma,
      use_sigmoid_probability=generator_config.use_sigmoid_probability)

  if model_config.include_mask:
    mask_head_config = model_config.mask_head
    mask_roi_aligner_config = model_config.mask_roi_aligner
    # Forward use_sync_bn like every other normalized head in this factory,
    # so the mask head honors the sync-BN setting in multi-replica training.
    mask_head = instance_heads.MaskHead(
        num_classes=model_config.num_classes,
        upsample_factor=mask_head_config.upsample_factor,
        num_convs=mask_head_config.num_convs,
        num_filters=mask_head_config.num_filters,
        use_separable_conv=mask_head_config.use_separable_conv,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer,
        class_agnostic=mask_head_config.class_agnostic)

    # Mask targets are sized to match the head output: the ROI crop size
    # scaled by the head's upsample factor.
    mask_sampler_obj = mask_sampler.MaskSampler(
        mask_target_size=(
            mask_roi_aligner_config.crop_size *
            mask_head_config.upsample_factor),
        num_sampled_masks=model_config.mask_sampler.num_sampled_masks)

    mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
        crop_size=mask_roi_aligner_config.crop_size,
        sample_offset=mask_roi_aligner_config.sample_offset)
  else:
    mask_head = None
    mask_sampler_obj = None
    mask_roi_aligner_obj = None

  model = maskrcnn_model.MaskRCNNModel(
      backbone=backbone,
      decoder=decoder,
      rpn_head=rpn_head,
      detection_head=detection_head,
      roi_generator=roi_generator_obj,
      roi_sampler=roi_sampler_cascade,
      roi_aligner=roi_aligner_obj,
      detection_generator=detection_generator_obj,
      mask_head=mask_head,
      mask_sampler=mask_sampler_obj,
      mask_roi_aligner=mask_roi_aligner_obj,
      class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
      cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size,
      outer_boxes_scale=model_config.outer_boxes_scale)
  return model
|
|
|
|
|
def build_retinanet(
    input_specs: tf_keras.layers.InputSpec,
    model_config: retinanet_cfg.RetinaNet,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    backbone: Optional[tf_keras.Model] = None,
    decoder: Optional[tf_keras.Model] = None
) -> tf_keras.Model:
  """Builds RetinaNet model.

  Args:
    input_specs: Spec of the model input. `input_specs.shape[1:]` is used to
      create a symbolic backbone input, and `shape[1]`/`shape[2]` are recorded
      as `input_image_size` in the TFLite post-processing config.
    model_config: The `RetinaNet` config.
    l2_regularizer: Optional kernel regularizer for the sub-networks built
      here.
    backbone: Optional pre-built backbone; built from `model_config.backbone`
      when falsy.
    decoder: Optional pre-built decoder; requested from
      `decoders.factory.build_decoder` when falsy. If it is still falsy
      afterwards, the head is not called on decoder features here.

  Returns:
    A `retinanet_model.RetinaNetModel` instance.
  """
  norm_cfg = model_config.norm_activation
  anchor_cfg = model_config.anchor

  if not backbone:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)
  symbolic_features = backbone(tf_keras.Input(input_specs.shape[1:]))

  if not decoder:
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  det_gen_cfg = model_config.detection_generator
  anchors_per_location = len(anchor_cfg.aspect_ratios) * anchor_cfg.num_scales

  attribute_head_dicts = []
  for attribute_cfg in head_cfg.attribute_heads or []:
    attribute_head_dicts.append(attribute_cfg.as_dict())

  head = dense_prediction_heads.RetinaNetHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_classes=model_config.num_classes,
      num_anchors_per_location=anchors_per_location,
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      attribute_heads=attribute_head_dicts,
      share_classification_heads=head_cfg.share_classification_heads,
      use_separable_conv=head_cfg.use_separable_conv,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      share_level_convs=head_cfg.share_level_convs,
  )

  if decoder:
    # Call the head once on symbolic decoder output; the result is unused.
    head(decoder(symbolic_features))

  # Copy of the TFLite post-processing settings, augmented with the
  # (height, width) taken from the input spec.
  tflite_cfg = dict(
      det_gen_cfg.tflite_post_processing.as_dict(),
      input_image_size=(input_specs.shape[1], input_specs.shape[2]),
  )

  generator = detection_generator.MultilevelDetectionGenerator(
      apply_nms=det_gen_cfg.apply_nms,
      pre_nms_top_k=det_gen_cfg.pre_nms_top_k,
      pre_nms_score_threshold=det_gen_cfg.pre_nms_score_threshold,
      nms_iou_threshold=det_gen_cfg.nms_iou_threshold,
      max_num_detections=det_gen_cfg.max_num_detections,
      nms_version=det_gen_cfg.nms_version,
      use_cpu_nms=det_gen_cfg.use_cpu_nms,
      soft_nms_sigma=det_gen_cfg.soft_nms_sigma,
      tflite_post_processing_config=tflite_cfg,
      return_decoded=det_gen_cfg.return_decoded,
      use_class_agnostic_nms=det_gen_cfg.use_class_agnostic_nms,
      box_coder_weights=det_gen_cfg.box_coder_weights,
  )

  return retinanet_model.RetinaNetModel(
      backbone,
      decoder,
      head,
      generator,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=anchor_cfg.num_scales,
      aspect_ratios=anchor_cfg.aspect_ratios,
      anchor_size=anchor_cfg.anchor_size)
|
|
|
|
|
def build_segmentation_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    backbone: Optional[tf_keras.Model] = None,
    decoder: Optional[tf_keras.Model] = None
) -> tf_keras.Model:
  """Builds Segmentation model.

  Args:
    input_specs: Spec of the model input, forwarded to the backbone builder
      when a backbone is built here.
    model_config: The `SemanticSegmentationModel` config.
    l2_regularizer: Optional kernel regularizer for the heads and any
      backbone/decoder built here.
    backbone: Optional pre-built backbone; built from `model_config.backbone`
      when falsy.
    decoder: Optional pre-built decoder; requested from
      `decoders.factory.build_decoder` when falsy.

  Returns:
    A `segmentation_model.SegmentationModel`, with a `MaskScoring` head
    attached when `model_config.mask_scoring_head` is set.
  """
  norm_cfg = model_config.norm_activation

  if not backbone:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)

  if not decoder:
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  # Activation / normalization / regularization settings shared by both
  # heads built below.
  shared_head_kwargs = dict(
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  seg_head_cfg = model_config.head
  seg_head = segmentation_heads.SegmentationHead(
      num_classes=model_config.num_classes,
      level=seg_head_cfg.level,
      num_convs=seg_head_cfg.num_convs,
      prediction_kernel_size=seg_head_cfg.prediction_kernel_size,
      num_filters=seg_head_cfg.num_filters,
      use_depthwise_convolution=seg_head_cfg.use_depthwise_convolution,
      upsample_factor=seg_head_cfg.upsample_factor,
      feature_fusion=seg_head_cfg.feature_fusion,
      low_level=seg_head_cfg.low_level,
      low_level_num_filters=seg_head_cfg.low_level_num_filters,
      logit_activation=seg_head_cfg.logit_activation,
      **shared_head_kwargs)

  scoring_head = (
      segmentation_heads.MaskScoring(
          num_classes=model_config.num_classes,
          **model_config.mask_scoring_head.as_dict(),
          **shared_head_kwargs)
      if model_config.mask_scoring_head else None)

  return segmentation_model.SegmentationModel(
      backbone, decoder, seg_head, mask_scoring_head=scoring_head)
|
|
|