Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """RetinaNet.""" | |
| from typing import Any, Mapping, List, Optional, Union, Sequence | |
| # Import libraries | |
| import tensorflow as tf, tf_keras | |
| from official.vision.ops import anchor | |
class RetinaNetModel(tf_keras.Model):
  """The RetinaNet model class.

  Chains a backbone, an optional decoder and a dense prediction head. Outside
  of training, a detection generator post-processes the raw head outputs into
  final detections.
  """

  def __init__(self,
               backbone: tf_keras.Model,
               decoder: Optional[tf_keras.Model],
               head: tf_keras.layers.Layer,
               detection_generator: tf_keras.layers.Layer,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
               aspect_ratios: Optional[List[float]] = None,
               anchor_size: Optional[float] = None,
               **kwargs):
    """Detection initialization function.

    Args:
      backbone: `tf_keras.Model` a backbone network.
      decoder: `tf_keras.Model` a decoder network. May be `None`, in which
        case the backbone features are fed to the head directly.
      head: `RetinaNetHead`, the RetinaNet head.
      detection_generator: the detection generator.
      min_level: Minimum level in output feature maps.
      max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added
        on each level. For instance, num_scales=2 adds one additional
        intermediate anchor scales [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratio
        anchors added on each level. The number indicates the ratio of width
        to height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three
        anchors on each scale level.
      anchor_size: A number representing the scale of size of the base
        anchor to the feature stride 2^level.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    # Everything needed to rebuild the model through `from_config`; the anchor
    # settings are also read back when anchors are generated in `call`.
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'head': head,
        'detection_generator': detection_generator,
        'min_level': min_level,
        'max_level': max_level,
        'num_scales': num_scales,
        'aspect_ratios': aspect_ratios,
        'anchor_size': anchor_size,
    }
    self._backbone = backbone
    self._decoder = decoder
    self._head = head
    self._detection_generator = detection_generator

  def call(self,
           images: Union[tf.Tensor, Sequence[tf.Tensor]],
           image_shape: Optional[tf.Tensor] = None,
           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
           output_intermediate_features: bool = False,
           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Forward pass of the RetinaNet model.

    Args:
      images: `Tensor` or a sequence of `Tensor`, the input batched images to
        the backbone network, whose shape(s) is [batch, height, width, 3]. If
        it is a sequence of `Tensor`, we will assume the anchors are generated
        based on the shape of the first image(s).
      image_shape: `Tensor`, the actual shape of the input images, whose shape
        is [batch, 2] where the last dimension is [height, width]. Note that
        this is the actual image shape excluding paddings. For example, images
        in the batch may be resized into different shapes before padding to
        the fixed size.
      anchor_boxes: a dict of tensors which includes multilevel anchors.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the anchor coordinates of a particular feature
            level, whose shape is [height_l, width_l, num_anchors_per_location].
        If `None`, anchors are generated from the static image size and tiled
        over the batch. Only used when `training` is falsy.
      output_intermediate_features: `bool` indicating whether to return the
        intermediate feature maps generated by backbone and decoder.
      training: `bool`, indicating whether it is in training mode.

    Returns:
      A dict of output tensors with the following entries.
      Always present:
        cls_outputs: a dict of tensors which includes scores of the
          predictions.
          - key: `str`, the level of the multilevel predictions.
          - values: `Tensor`, the box scores predicted from a particular
              feature level, whose shape is
              [batch, height_l, width_l, num_classes * num_anchors_per_location].
        box_outputs: a dict of tensors which includes coordinates of the
          predictions.
          - key: `str`, the level of the multilevel predictions.
          - values: `Tensor`, the box coordinates predicted from a particular
              feature level, whose shape is
              [batch, height_l, width_l, 4 * num_anchors_per_location].
      Present only when the head returns non-empty attributes:
        attribute_outputs: a dict of (attribute_name, attribute_predictions).
          Each attribute prediction is a dict that includes:
          - key: `str`, the level of the multilevel predictions.
          - values: `Tensor`, the attribute predictions from a particular
              feature level, whose shape is
              [batch, height_l, width_l, att_size * num_anchors_per_location].
      Present only when `output_intermediate_features` is set:
        backbone_<key>, decoder_<key>: the intermediate feature maps.
      Present only when `training` is falsy:
        detection_boxes, detection_scores, detection_classes, num_detections:
          when the detection generator config has `apply_nms` enabled.
        decoded_boxes, decoded_box_scores and, if the generator provides it,
          decoded_box_attributes: when `apply_nms` is disabled, or when it is
          enabled together with `return_decoded`.
        detection_attributes: when the head returns non-empty attributes.

    Raises:
      ValueError: if anchors must be generated and `images` is neither a
        `tf.Tensor` nor a sequence.
    """
    outputs = {}

    # Feature extraction.
    features = self.backbone(images)
    if output_intermediate_features:
      outputs.update(
          {'backbone_{}'.format(k): v for k, v in features.items()})
    if self.decoder is not None:
      features = self.decoder(features)
    if output_intermediate_features:
      # Without a decoder these are the backbone features again.
      outputs.update(
          {'decoder_{}'.format(k): v for k, v in features.items()})

    # Dense prediction. `raw_attributes` can be empty.
    raw_scores, raw_boxes, raw_attributes = self.head(features)
    outputs.update({
        'cls_outputs': raw_scores,
        'box_outputs': raw_boxes,
    })

    if training:
      # Training only needs the raw head outputs; skip post-processing.
      if raw_attributes:
        outputs.update({'attribute_outputs': raw_attributes})
      return outputs

    # Generate anchor boxes for this batch if not provided.
    if anchor_boxes is None:
      anchor_boxes = self._generate_anchor_boxes(images)

    # Post-processing.
    final_results = self.detection_generator(raw_boxes, raw_scores,
                                             anchor_boxes, image_shape,
                                             raw_attributes)

    def _update_decoded_results():
      outputs.update({
          'decoded_boxes': final_results['decoded_boxes'],
          'decoded_box_scores': final_results['decoded_box_scores'],
      })
      if final_results.get('decoded_box_attributes') is not None:
        outputs['decoded_box_attributes'] = final_results[
            'decoded_box_attributes'
        ]

    generator_config = self.detection_generator.get_config()
    if generator_config['apply_nms']:
      outputs.update({
          'detection_boxes': final_results['detection_boxes'],
          'detection_scores': final_results['detection_scores'],
          'detection_classes': final_results['detection_classes'],
          'num_detections': final_results['num_detections'],
      })
      # Users can choose to include the decoded results (boxes before NMS) in
      # the output tensor dict even if `apply_nms` is set to `True`.
      if generator_config['return_decoded']:
        _update_decoded_results()
    else:
      _update_decoded_results()

    if raw_attributes:
      # NOTE(review): assumes the generator returns 'detection_attributes'
      # whenever attributes are predicted, including with NMS disabled --
      # confirm against the detection generator.
      outputs.update({
          'attribute_outputs': raw_attributes,
          'detection_attributes': final_results['detection_attributes'],
      })
    return outputs

  def _generate_anchor_boxes(
      self, images: Union[tf.Tensor, Sequence[tf.Tensor]]
  ) -> Mapping[str, tf.Tensor]:
    """Builds multilevel anchors for `images` and tiles them over the batch."""
    if isinstance(images, Sequence):
      primary_images = images[0]
    elif isinstance(images, tf.Tensor):
      primary_images = images
    else:
      raise ValueError(
          'Input should be a tf.Tensor or a sequence of tf.Tensor, not {}.'
          .format(type(images)))

    # The anchor grid is laid out from the static height/width.
    _, image_height, image_width, _ = primary_images.get_shape().as_list()
    anchor_boxes = anchor.Anchor(
        min_level=self._config_dict['min_level'],
        max_level=self._config_dict['max_level'],
        num_scales=self._config_dict['num_scales'],
        aspect_ratios=self._config_dict['aspect_ratios'],
        anchor_size=self._config_dict['anchor_size'],
        image_size=(image_height, image_width)).multilevel_boxes

    # The batch dimension is read dynamically and prepended to every level.
    batch_size = tf.shape(primary_images)[0]
    for level in anchor_boxes:
      anchor_boxes[level] = tf.tile(
          tf.expand_dims(anchor_boxes[level], axis=0), [batch_size, 1, 1, 1])
    return anchor_boxes

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(backbone=self.backbone, head=self.head)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    return items

  @property
  def backbone(self) -> tf_keras.Model:
    """The backbone network passed to the constructor."""
    return self._backbone

  @property
  def decoder(self) -> Optional[tf_keras.Model]:
    """The decoder network passed to the constructor; may be `None`."""
    return self._decoder

  @property
  def head(self) -> tf_keras.layers.Layer:
    """The dense prediction head passed to the constructor."""
    return self._head

  @property
  def detection_generator(self) -> tf_keras.layers.Layer:
    """The detection generator passed to the constructor."""
    return self._detection_generator

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments recorded at initialization."""
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    """Builds a model from a config as returned by `get_config`."""
    return cls(**config)