# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains loss builder classes used in the DeepLab model."""

import collections
from typing import Any, Dict, Text, Tuple, Optional

import tensorflow as tf

from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.model.loss import base_loss
from deeplab2.model.loss import max_deeplab_loss
def _create_loss_and_weight(
    loss_options: Optional[config_pb2.LossOptions.SingleLossOptions],
    gt_key: Text,
    pred_key: Text,
    weight_key: Text,
    **kwargs: Any) -> Tuple[Optional[tf.keras.losses.Loss], float]:
  """Creates a loss and its weight from loss options.

  Args:
    loss_options: Loss options as defined by
      config_pb2.LossOptions.SingleLossOptions or None.
    gt_key: A key to extract the ground-truth from a dictionary.
    pred_key: A key to extract the prediction from a dictionary.
    weight_key: A key to extract the per-pixel weights from a dictionary.
    **kwargs: Additional parameters to initialize the loss.

  Returns:
    A tuple of an instance of tf.keras.losses.Loss (or None when loss_options
      is None) and its corresponding weight.

  Raises:
    ValueError: An error occurs when the loss name is not a valid loss.
  """
  if loss_options is None:
    # No loss configured: signal "disabled" with a zero weight.
    return None, 0
  if loss_options.name == 'softmax_cross_entropy':
    # Only the cross-entropy loss consumes the extra kwargs (num_classes,
    # ignore_label); the regression losses below take none.
    return base_loss.TopKCrossEntropyLoss(
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent,
        **kwargs), loss_options.weight
  elif loss_options.name == 'l1':
    return base_loss.TopKGeneralLoss(
        base_loss.mean_absolute_error,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent), loss_options.weight
  elif loss_options.name == 'mse':
    return base_loss.TopKGeneralLoss(
        base_loss.mean_squared_error,
        gt_key,
        pred_key,
        weight_key,
        top_k_percent_pixels=loss_options.top_k_percent), loss_options.weight
  raise ValueError('Loss %s is not a valid loss.' % loss_options.name)
class DeepLabFamilyLoss(tf.keras.layers.Layer):
  """This class contains code to build and call losses for DeepLabFamilyLoss."""

  def __init__(
      self,
      loss_options: config_pb2.LossOptions,
      num_classes: Optional[int],
      ignore_label: Optional[int],
      thing_class_ids: Tuple[int, ...]):
    """Initializes the losses for the DeepLab family (e.g., Panoptic-DeepLab).

    Args:
      loss_options: Loss options as defined by config_pb2.LossOptions.
      num_classes: An integer specifying the number of classes in the dataset.
      ignore_label: An optional integer specifying the ignore label or None.
      thing_class_ids: A tuple of length [N] containing N thing indices.
    """
    super(DeepLabFamilyLoss, self).__init__(name='DeepLabFamilyLoss')

    # Single-term losses are losses that have only one loss term and thus each
    # loss function directly returns a single tensor as the loss value, as
    # opposed to multi-term losses that involve multiple terms and return a
    # dictionary of loss values.
    # Maps loss name -> (loss_func, weight); insertion order (OrderedDict)
    # determines the order of the per-loss entries returned by call().
    self._single_term_loss_func_and_weight_dict = collections.OrderedDict()
    # Keys that call() returns in addition to the single-term loss names.
    self._extra_loss_names = [common.TOTAL_LOSS]

    # Each loss below is only built if its field is present in the config;
    # _create_loss_and_weight wires the ground-truth/prediction/weight keys.
    if loss_options.HasField(common.SEMANTIC_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.SEMANTIC_LOSS] = _create_loss_and_weight(
              loss_options.semantic_loss,
              common.GT_SEMANTIC_KEY,
              common.PRED_SEMANTIC_LOGITS_KEY,
              common.SEMANTIC_LOSS_WEIGHT_KEY,
              num_classes=num_classes,
              ignore_label=ignore_label)

    if loss_options.HasField(common.CENTER_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.CENTER_LOSS] = _create_loss_and_weight(
              loss_options.center_loss, common.GT_INSTANCE_CENTER_KEY,
              common.PRED_CENTER_HEATMAP_KEY, common.CENTER_LOSS_WEIGHT_KEY)

    if loss_options.HasField(common.REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.regression_loss, common.GT_INSTANCE_REGRESSION_KEY,
              common.PRED_OFFSET_MAP_KEY, common.REGRESSION_LOSS_WEIGHT_KEY)

    # Currently, only used for Motion-DeepLab.
    if loss_options.HasField(common.MOTION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.MOTION_LOSS] = _create_loss_and_weight(
              loss_options.motion_loss, common.GT_FRAME_OFFSET_KEY,
              common.PRED_FRAME_OFFSET_MAP_KEY,
              common.FRAME_REGRESSION_LOSS_WEIGHT_KEY)

    # Next-frame regression loss used in ViP-DeepLab.
    if loss_options.HasField(common.NEXT_REGRESSION_LOSS):
      self._single_term_loss_func_and_weight_dict[
          common.NEXT_REGRESSION_LOSS] = _create_loss_and_weight(
              loss_options.next_regression_loss,
              common.GT_NEXT_INSTANCE_REGRESSION_KEY,
              common.PRED_NEXT_OFFSET_MAP_KEY,
              common.NEXT_REGRESSION_LOSS_WEIGHT_KEY)

    # Multi-term losses that return dictionaries of loss terms.
    self._multi_term_losses = []

    # MaXDeepLabLoss optionally returns four loss terms in total:
    # - common.PQ_STYLE_LOSS_CLASS_TERM
    # - common.PQ_STYLE_LOSS_MASK_DICE_TERM
    # - common.MASK_ID_CROSS_ENTROPY_LOSS
    # - common.INSTANCE_DISCRIMINATION_LOSS
    # Any one of the three config fields being set is enough to build it.
    if any([loss_options.HasField('pq_style_loss'),
            loss_options.HasField('mask_id_cross_entropy_loss'),
            loss_options.HasField('instance_discrimination_loss')]):
      self._multi_term_losses.append(max_deeplab_loss.MaXDeepLabLoss(
          loss_options, ignore_label, thing_class_ids))

    for multi_term_loss in self._multi_term_losses:
      self._extra_loss_names += multi_term_loss.loss_terms

  def get_loss_names(self):
    # Keep track of all the keys that will be returned in self.call().
    loss_names = list(self._single_term_loss_func_and_weight_dict.keys())
    return loss_names + self._extra_loss_names

  def call(self, y_true: Dict[Text, tf.Tensor],
           y_pred: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Performs the loss computations given ground-truth and predictions.

    The loss is computed for each sample separately. Currently, smoothed
    ground-truth labels are not supported.

    Args:
      y_true: A dictionary of tf.Tensor containing all ground-truth data to
        compute the loss. Depending on the configuration, the dict has to
        contain common.GT_SEMANTIC_KEY, and optionally
        common.GT_INSTANCE_CENTER_KEY, common.GT_INSTANCE_REGRESSION_KEY, and
        common.GT_FRAME_OFFSET_KEY.
      y_pred: A dictionary of tf.Tensor containing all predictions to compute
        the loss. Depending on the configuration, the dict has to contain
        common.PRED_SEMANTIC_LOGITS_KEY, and optionally
        common.PRED_CENTER_HEATMAP_KEY, common.PRED_OFFSET_MAP_KEY, and
        common.PRED_FRAME_OFFSET_MAP_KEY.

    Returns:
      The loss as a dict of tf.Tensor, optionally containing the following:
      - common.SEMANTIC_LOSS: [batch].
      - common.CENTER_LOSS: [batch].
      - common.REGRESSION_LOSS: [batch].
      - common.MOTION_LOSS: [batch], the frame offset regression loss.
      - common.NEXT_REGRESSION_LOSS: [batch], the next regression loss.

    Raises:
      AssertionError: If the keys of the resulting_dict do not match
        self.get_loss_names().
      AssertionError: The keys of the resulting_dict overlap with the keys of
        the loss_dict.
    """
    resulting_dict = collections.OrderedDict()

    # Single-term losses: each loss is scaled by its configured weight here
    # (the loss functions themselves return unweighted values).
    for loss_name, func_and_weight in (
        self._single_term_loss_func_and_weight_dict.items()):
      loss_func, loss_weight = func_and_weight
      loss_value = loss_func(y_true, y_pred)
      resulting_dict[loss_name] = loss_value * loss_weight

    # Multi-term losses predict a dictionary, so we handle them differently.
    # NOTE: multi-term loss values are merged as-is; any weighting is assumed
    # to happen inside the multi-term loss itself.
    for multi_term_loss in self._multi_term_losses:
      loss_dict = multi_term_loss((y_true, y_pred))
      if not set(loss_dict).isdisjoint(resulting_dict):
        raise AssertionError('The keys of the resulting_dict overlap with the '
                             'keys of the loss_dict.')
      resulting_dict.update(loss_dict)

    # Also include the total loss in the resulting_dict: the sum of all
    # (already weighted) loss terms, per sample.
    total_loss = tf.math.accumulate_n(list(resulting_dict.values()))
    resulting_dict[common.TOTAL_LOSS] = total_loss

    # Sanity check: the returned keys must match what get_loss_names()
    # advertised at construction time.
    if sorted(resulting_dict.keys()) != sorted(self.get_loss_names()):
      raise AssertionError(
          'The keys of the resulting_dict should match self.get_loss_names().')
    return resulting_dict