|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Detection input and model functions for serving/inference."""
|
|
|
| import math
|
| from typing import Mapping, Tuple
|
|
|
| from absl import logging
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.vision import configs
|
| from official.vision.modeling import factory
|
| from official.vision.ops import anchor
|
| from official.vision.ops import box_ops
|
| from official.vision.ops import preprocess_ops
|
| from official.vision.serving import export_base
|
|
|
|
|
class DetectionModule(export_base.ExportModule):
  """Export module that wraps a detection model for serving/inference.

  Mask R-CNN and RetinaNet model configs are supported; any other model
  config makes `_build_model` raise `ValueError`.
  """

  @property
  def _padded_size(self):
    """Returns the spatial size of the images fed to the model.

    When `task.train_data.parser.pad` is set, the input image size is passed
    through `preprocess_ops.compute_padded_size` with a stride of
    `2**max_level`; otherwise the input image size is returned unchanged.
    """
    if self.params.task.train_data.parser.pad:
      return preprocess_ops.compute_padded_size(
          self._input_image_size, 2**self.params.task.model.max_level
      )
    else:
      return self._input_image_size

  def _build_model(self):
    """Builds the detection model described by `self.params.task.model`.

    If the batch size is dynamic (`None`) and the configured NMS version is
    not one of `batched`, `v2` or `v3`, the config is switched to `batched`
    (the config object is modified in place) before the model is built.

    Returns:
      The model returned by `factory.build_maskrcnn` or
      `factory.build_retinanet`, depending on the model config type.

    Raises:
      ValueError: If the model config is neither a MaskRCNN nor a RetinaNet
        config.
    """
    model_config = self.params.task.model

    nms_versions_supporting_dynamic_batch_size = {'batched', 'v2', 'v3'}
    nms_version = model_config.detection_generator.nms_version
    if (self._batch_size is None and
        nms_version not in nms_versions_supporting_dynamic_batch_size):
      logging.info('nms_version is set to `batched` because `%s` '
                   'does not support with dynamic batch size.', nms_version)
      model_config.detection_generator.nms_version = 'batched'

    input_specs = tf_keras.layers.InputSpec(shape=[
        self._batch_size, *self._padded_size, 3])

    if isinstance(model_config, configs.maskrcnn.MaskRCNN):
      model = factory.build_maskrcnn(
          input_specs=input_specs, model_config=model_config)
    elif isinstance(model_config, configs.retinanet.RetinaNet):
      model = factory.build_retinanet(
          input_specs=input_specs, model_config=model_config)
    else:
      raise ValueError('Detection module not implemented for {} model.'.format(
          type(model_config)))

    return model

  def _build_anchor_boxes(self):
    """Builds and returns anchor boxes for the padded image size.

    The anchor generator is configured from the model's `min_level`,
    `max_level` and `anchor` settings.
    """
    model_params = self.params.task.model
    input_anchor = anchor.build_anchor_generator(
        min_level=model_params.min_level,
        max_level=model_params.max_level,
        num_scales=model_params.anchor.num_scales,
        aspect_ratios=model_params.anchor.aspect_ratios,
        anchor_size=model_params.anchor.anchor_size)
    return input_anchor(image_size=self._padded_size)

  def _build_inputs(self, image):
    """Builds detection model inputs for serving.

    Args:
      image: A single image tensor (this method is mapped over the batch by
        `preprocess`).

    Returns:
      image: The normalized, resized and padded image.
      anchor_boxes: The anchor boxes built by `_build_anchor_boxes`.
      image_info: The resizing details returned by
        `preprocess_ops.resize_and_crop_image`.
    """
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    # Scale augmentation is pinned to 1.0: serving must be deterministic.
    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        self._input_image_size,
        padded_size=self._padded_size,
        aug_scale_min=1.0,
        aug_scale_max=1.0,
        keep_aspect_ratio=self.params.task.train_data.parser.keep_aspect_ratio,
    )
    anchor_boxes = self._build_anchor_boxes()

    return image, anchor_boxes, image_info

  def _normalize_coordinates(self, detections_dict, dict_keys, image_info):
    """Normalizes detection coordinates between 0 and 1.

    For every requested key present in `detections_dict`, the boxes are
    divided by row 2 of `image_info` (tiled to cover all four coordinates),
    normalized against row 0 of `image_info` via `box_ops.normalize_boxes`,
    and finally clipped to [0, 1]. Keys that are missing are skipped.

    Args:
      detections_dict: Dictionary containing the output of the model prediction.
        It is updated in place.
      dict_keys: Key names corresponding to the tensors of the output dictionary
        that we want to update.
      image_info: Tensor containing the details of the image resizing.

    Returns:
      detections_dict: Updated detection dictionary.
    """
    # Both tensors are the same for every key, so build them only once.
    # NOTE(review): rows 2 and 0 look like (scale, original size) - confirm.
    box_scale = tf.tile(image_info[:, 2:3, :], [1, 1, 2])
    image_shape = image_info[:, 0:1, :]

    for key in dict_keys:
      if key not in detections_dict:
        continue
      boxes = box_ops.normalize_boxes(
          detections_dict[key] / box_scale, image_shape
      )
      detections_dict[key] = tf.clip_by_value(boxes, 0.0, 1.0)

    return detections_dict

  def preprocess(
      self, images: tf.Tensor
  ) -> Tuple[tf.Tensor, Mapping[str, tf.Tensor], tf.Tensor]:
    """Preprocesses inputs to be suitable for the model.

    The batch is cast to float32 and `_build_inputs` is mapped over it on the
    CPU.

    Args:
      images: The images tensor.

    Returns:
      images: The images tensor after `_build_inputs` (float32, normalized,
        resized and padded to `_padded_size`).
      anchor_boxes: Dict mapping anchor levels to anchor boxes.
      image_info: Tensor containing the details of the image resizing.
    """
    model_params = self.params.task.model
    with tf.device('cpu:0'):
      images = tf.cast(images, dtype=tf.float32)

      # Read the property once; unpacking (instead of `+ [3]`) accepts any
      # sequence type, like the `*self._padded_size` use in `_build_model`.
      padded_size = self._padded_size

      # Per-element output signatures required by `tf.map_fn`.
      images_spec = tf.TensorSpec(shape=[*padded_size, 3], dtype=tf.float32)

      # Four box coordinates per anchor at every feature-map location.
      num_anchors = model_params.anchor.num_scales * len(
          model_params.anchor.aspect_ratios) * 4
      anchor_specs = {
          str(level): tf.TensorSpec(
              shape=[
                  math.ceil(padded_size[0] / 2**level),
                  math.ceil(padded_size[1] / 2**level),
                  num_anchors,
              ],
              dtype=tf.float32)
          for level in range(model_params.min_level,
                             model_params.max_level + 1)
      }

      image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

      images, anchor_boxes, image_info = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._build_inputs,
              elems=images,
              fn_output_signature=(images_spec, anchor_specs,
                                   image_info_spec),
              parallel_iterations=32))

      return images, anchor_boxes, image_info

  def serve(self, images: tf.Tensor):
    """Preprocesses the images (unless the input is TFLite) and runs inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]

    Returns:
      A dict of output tensors:
        * With NMS applied: `detection_boxes`, `detection_scores`,
          `detection_classes`, `num_detections` and, when the model produces
          it, `detection_outer_boxes`.
        * Without NMS: `decoded_boxes` and `decoded_box_scores`.
        * `detection_masks`, when the model produces it.
        * For RetinaNet with `output_intermediate_features` enabled, every
          model output whose key starts with `backbone_` or `decoder_`.
        * `image_info`, unless the NMS version is `tflite`.
    """
    model_params = self.params.task.model
    is_retinanet = isinstance(model_params, configs.retinanet.RetinaNet)

    if self._input_type != 'tflite':
      images, anchor_boxes, image_info = self.preprocess(images)
    else:
      # TFLite input: the images are handed to the model unchanged, so only
      # the anchors and a matching `image_info` have to be built here.
      with tf.device('cpu:0'):
        anchor_boxes = self._build_anchor_boxes()
        # One [4, 2] entry: input size twice, unit scale and zero offset.
        image_info = tf.convert_to_tensor([[
            self._input_image_size, self._input_image_size, [1.0, 1.0], [0, 0]
        ]], dtype=tf.float32)
    input_image_shape = image_info[:, 1, :]

    # NOTE(review): `model.call` is invoked directly instead of `model(...)`,
    # presumably to bypass Keras `__call__` handling during export - confirm.
    model_call_kwargs = {
        'images': images,
        'image_shape': input_image_shape,
        'anchor_boxes': anchor_boxes,
        'training': False,
    }
    if is_retinanet:
      model_call_kwargs['output_intermediate_features'] = (
          self.params.task.export_config.output_intermediate_features
      )
    detections = self.model.call(**model_call_kwargs)

    if model_params.detection_generator.apply_nms:
      # The export config is only consulted for RetinaNet models.
      if is_retinanet:
        export_config = self.params.task.export_config

        if export_config.output_normalized_coordinates:
          keys = ['detection_boxes', 'detection_outer_boxes']
          detections = self._normalize_coordinates(detections, keys, image_info)

        if export_config.cast_num_detections_to_float:
          detections['num_detections'] = tf.cast(
              detections['num_detections'], dtype=tf.float32)
        if export_config.cast_detection_classes_to_float:
          detections['detection_classes'] = tf.cast(
              detections['detection_classes'], dtype=tf.float32)

      final_outputs = {
          'detection_boxes': detections['detection_boxes'],
          'detection_scores': detections['detection_scores'],
          'detection_classes': detections['detection_classes'],
          'num_detections': detections['num_detections']
      }
      if 'detection_outer_boxes' in detections:
        final_outputs['detection_outer_boxes'] = (
            detections['detection_outer_boxes'])
    else:
      # Without NMS the raw decoded boxes and scores are exported instead.
      if is_retinanet:
        export_config = self.params.task.export_config

        if export_config.output_normalized_coordinates:
          keys = ['decoded_boxes']
          detections = self._normalize_coordinates(detections, keys, image_info)
      final_outputs = {
          'decoded_boxes': detections['decoded_boxes'],
          'decoded_box_scores': detections['decoded_box_scores']
      }

    if 'detection_masks' in detections:
      final_outputs['detection_masks'] = detections['detection_masks']
    if (
        is_retinanet
        and self.params.task.export_config.output_intermediate_features
    ):
      final_outputs.update(
          {
              k: v
              for k, v in detections.items()
              if k.startswith(('backbone_', 'decoder_'))
          }
      )

    if model_params.detection_generator.nms_version != 'tflite':
      final_outputs.update({'image_info': image_info})
    return final_outputs
|
|
|