File size: 10,986 Bytes
1327f34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for all semantic segmentation models."""

import functools
from typing import Any, Callable, List, Dict, Optional, Tuple, Union

from flax.training import common_utils
from immutabledict import immutabledict
import jax.numpy as jnp
import numpy as np
from scenic.model_lib.base_models import base_model
from scenic.model_lib.base_models import model_utils


# Type of a whole-dataset ("global") metric function: it receives a list of
# arrays (one entry per evaluated batch) together with a metadata dict, and
# returns a dict mapping metric names to float values.
GlobalMetricFn = Callable[[List[jnp.ndarray], Dict[str, Any]], Dict[str, float]]


def num_pixels(logits: jnp.ndarray,
               one_hot_targets: jnp.ndarray,
               weights: Optional[jnp.ndarray] = None) -> float:
  """Counts the pixels that a segmentation metric should be normalized by.

  The signature mirrors the other metric functions so that it can be used as
  the normalizer half of a (metric_fn, normalizer_fn) pair.

  Args:
    logits: Ignored; accepted only to match the common metric API.
    one_hot_targets: One-hot targets. Only the first three entries of its
      shape are read, and only when `weights` is None.
    weights: Optional pixel-level mask with exactly three dimensions (e.g. to
      exclude padded pixels). When given, the count is the sum of this mask.

  Returns:
    The sum of `weights` if it is provided, otherwise the product of the first
    three dimensions of `one_hot_targets`.
  """
  del logits  # Present only for API compatibility.
  if weights is not None:
    assert weights.ndim == 3, (
        'For segmentation task, the weights should be a pixel level mask.')
    return weights.sum()  # pytype: disable=bad-return-type  # jax-ndarray
  # No mask: every position along the first three target axes counts.
  leading_dims = one_hot_targets.shape[:3]
  return np.prod(leading_dims)


# Standard default metrics for the semantic segmentation models.
# Each entry maps a metric name to a (metric_fn, normalizer_fn) pair. Both
# functions are invoked as fn(logits, one_hot_targets, weights) by
# `semantic_segmentation_metrics_function`; the caller divides the metric by
# the normalizer to obtain the mean.
_SEMANTIC_SEGMENTATION_METRICS = immutabledict({
    'accuracy': (model_utils.weighted_correctly_classified, num_pixels),

    # The loss is already normalized, so the normalizer is a constant 1.0
    # (instead of num_pixels) regardless of the arguments it is called with:
    'loss': (model_utils.weighted_softmax_cross_entropy, lambda *a, **kw: 1.0)
})


def semantic_segmentation_metrics_function(
    logits: jnp.ndarray,
    batch: base_model.Batch,
    target_is_onehot: bool = False,
    metrics: base_model.MetricNormalizerFnDict = _SEMANTIC_SEGMENTATION_METRICS,
    axis_name: Union[str, Tuple[str, ...]] = 'batch',
) -> Dict[str, Tuple[jnp.ndarray, jnp.ndarray]]:
  """Evaluates every configured semantic segmentation metric on one batch.

  Every entry of `metrics` is a pair of callables sharing the API
    ```fn(logits, one_hot_targets, weights)```
  where the first computes the (summed) metric and the second computes its
  normalizer. Only metrics of the form 1/N sum f(inputs, targets) are
  supported: to get the aggregate value, sum metric and normalizer over all
  batches and divide. That division is left to the caller.

  Args:
   logits: Model output; its last axis is used as the number of classes when
     the labels have to be converted to one-hot.
   batch: Batch of data that has 'label' and optionally 'batch_mask'.
   target_is_onehot: Whether batch['label'] is already one-hot encoded.
   metrics: Mapping from metric name to a (metric_fn, normalizer_fn) pair.
   axis_name: Axis name(s) forwarded to the psum over devices.

  Returns:
    A dict whose keys are the metric names and whose values are the
    (metric, normalizer) results returned by
    `model_utils.psum_metric_normalizer`.
  """
  one_hot_targets = (
      batch['label'] if target_is_onehot else common_utils.onehot(
          batch['label'], logits.shape[-1]))
  # 'batch_mask' is optional, in which case the metric functions get None.
  weights = batch.get('batch_mask')

  def _evaluate(fn_pair):
    """Runs one (metric_fn, normalizer_fn) pair and psums the results."""
    metric_value = fn_pair[0](logits, one_hot_targets, weights)  # pytype: disable=wrong-arg-types  # jax-types
    normalizer_value = fn_pair[1](logits, one_hot_targets, weights)  # pytype: disable=wrong-arg-types  # jax-types
    # The psum is needed for correct multihost evaluation: only host 0 reports
    # the metrics, so the values must be aggregated across all hosts/devices.
    return model_utils.psum_metric_normalizer(  # pytype: disable=wrong-arg-types  # jax-ndarray
        (metric_value, normalizer_value), axis_name=axis_name)

  return {name: _evaluate(fn_pair) for name, fn_pair in metrics.items()}


class SegmentationModel(base_model.BaseModel):
  """Defines commonalities between all semantic segmentation models.

  A model is a class with three members: get_metrics_fn, loss_function, and a
  flax_model.

  get_metrics_fn returns a callable function, metric_fn, that calculates the
  metrics and returns a dictionary. The metric function computes f(x_i, y_i) on
  a minibatch, it has API:
    ```metric_fn(logits, label, weights).```

  The trainer will then aggregate and compute the mean across all samples
  evaluated.

  loss_function is a function of API
    ```loss = loss_function(logits, batch, model_params=None).```

  This model class defines a softmax_cross_entropy_loss with weight decay,
  where the weight decay factor is determined by config.l2_decay_factor.

  flax_model is returned from the build_flax_model function. A typical
  usage pattern will be:
    ```
    model_cls = model_lib.models.get_model_cls('simple_cnn_segmentation')
    model = model_cls(config, dataset.meta_data)
    flax_model = model.build_flax_model()
    dummy_input = jnp.zeros(input_shape, model_input_dtype)
    model_state, params = flax_model.init(
        rng, dummy_input, train=False).pop('params')
    ```
  And this is how to call the model:
    ```
    variables = {'params': params, **model_state}
    logits, new_model_state = flax_model.apply(variables, inputs, ...)
    ```
  """

  def get_metrics_fn(self, split: Optional[str] = None) -> base_model.MetricFn:
    """Returns a callable metric function for the model.

    Args:
      split: The split for which we calculate the metrics. It should be one of
        the ['train',  'validation', 'test']. Currently ignored: the same
        metric function is returned for every split.

    Returns:
      A metric function with the following API: ```metrics_fn(logits,
      batch)```. It is `semantic_segmentation_metrics_function` with
      `target_is_onehot` taken from the dataset metadata (default False) and
      the default segmentation metrics bound.
    """
    del split  # For all splits, we return the same metric functions.
    return functools.partial(
        semantic_segmentation_metrics_function,
        target_is_onehot=self.dataset_meta_data.get('target_is_onehot', False),
        metrics=_SEMANTIC_SEGMENTATION_METRICS)

  def loss_function(self,
                    logits: jnp.ndarray,
                    batch: base_model.Batch,
                    model_params: Optional[jnp.ndarray] = None) -> float:
    """Returns softmax cross entropy loss with an L2 penalty on the weights.

    Args:
      logits: Output of model; the last axis indexes the classes.
      batch: Batch of data that has 'label' and optionally 'batch_mask'.
      model_params: Parameters of the model, for optionally applying
        regularization. Only used when `l2_decay_factor` is set in the config.

    Returns:
      Total loss: the weighted softmax cross entropy, plus
      `0.5 * l2_decay_factor * l2_regularization(model_params)` when
      `config.l2_decay_factor` is not None.
    """
    weights = batch.get('batch_mask')  # Optional; None when not provided.

    if self.dataset_meta_data.get('target_is_onehot', False):
      one_hot_targets = batch['label']
    else:
      one_hot_targets = common_utils.onehot(batch['label'], logits.shape[-1])

    sof_ce_loss = model_utils.weighted_softmax_cross_entropy(
        logits,
        one_hot_targets,
        weights,
        label_smoothing=self.config.get('label_smoothing'),
        label_weights=self.get_label_weights())

    # Read the factor once so the "is it set?" check and the value that is
    # used can never disagree.
    l2_decay_factor = self.config.get('l2_decay_factor')
    if l2_decay_factor is None:
      total_loss = sof_ce_loss
    else:
      l2_loss = model_utils.l2_regularization(model_params)
      total_loss = sof_ce_loss + 0.5 * l2_decay_factor * l2_loss
    return total_loss  # pytype: disable=bad-return-type  # jax-ndarray

  def get_label_weights(self) -> jnp.ndarray:
    """Returns labels' weights to be used for computing weighted loss.

    This can be used for weighting the loss terms based on the amount of
    available data for each class, when we have unbalanced data for different
    classes.

    Returns:
      None when `class_rebalancing_factor` is unset or zero in the config.
      Otherwise an array with one weight per class, scaled so that the weights
      sum to the number of classes.

    Raises:
      ValueError: If `class_rebalancing_factor` is nonzero but
        `class_proportions` is missing from `dataset_meta_data`, or if
        `class_rebalancing_factor` lies outside [0.0, 1.0].
    """
    w = self.config.get('class_rebalancing_factor')
    if not w:
      return None  # pytype: disable=bad-return-type  # jax-ndarray
    if 'class_proportions' not in self.dataset_meta_data:
      raise ValueError(
          'When `class_rebalancing_factor` is nonzero, `class_proportions` must'
          ' be provided in `dataset_meta_data`.')
    # Validate with an explicit exception rather than `assert`: asserts are
    # removed under `python -O`, and an out-of-range factor would silently
    # produce negative or otherwise meaningless class weights.
    if not 0.0 <= w <= 1.0:
      raise ValueError('`class_rebalancing_factor` must be in [0.0, 1.0]')
    proportions = self.dataset_meta_data['class_proportions']
    # Normalize to a distribution; the floor avoids division by zero below for
    # classes that have a zero proportion.
    proportions = np.maximum(proportions / np.sum(proportions), 1e-8)
    # Interpolate between no rebalancing (w==0.0) and full reweighting (w==1.0):
    proportions = w * proportions + (1.0 - w)
    weights = 1.0 / proportions
    weights /= np.sum(weights)  # Normalize so weights sum to 1.
    weights *= len(weights)  # Scale so weights sum to num_classes.
    return weights

  def get_global_metrics_fn(self) -> GlobalMetricFn:
    """Returns a callable metric function for global metrics.

      The return function implements metrics that require the prediction for the
      entire test/validation dataset in one place and has the following API:
        ```global_metrics_fn(all_confusion_mats, dataset_metadata)```
      If return None, no global metrics will be computed.
    """
    return global_metrics_fn

  def build_flax_model(self):
    """Builds the flax model; must be overridden by subclasses."""
    raise NotImplementedError('Subclasses must implement build_flax_model().')

  def default_flax_model_config(self):
    """Default config for the flax model that is built in `build_flax_model`.

    This function in particular serves the testing functions and is supposed to
    provide the config values that are passed to the flax_model when it is
    built in the `build_flax_model` function, e.g., `model_dtype_str`.
    """
    raise NotImplementedError(
        'Subclasses must implement default_flax_model_config().')


def global_metrics_fn(all_confusion_mats: List[jnp.ndarray],
                      dataset_metadata: Dict[str, Any]) -> Dict[str, float]:
  """Computes whole-dataset mean IoU and per-class IoU.

  Args:
    all_confusion_mats: List with one entry per eval batch. After summing over
      the list, the result must have shape
      [batch_size, num_classes, num_classes].
    dataset_metadata: Dataset metadata. If it contains 'class_names', the
      name of each class is appended to that class' metric tag.

  Returns:
    A dict with the key 'mean_iou' plus one 'iou_per_class/<index>' entry
    (optionally suffixed with '_<class name>') per class, all as floats.
  """
  assert isinstance(all_confusion_mats, list)  # List of eval batches.
  # First collapse the list of eval batches element-wise ...
  batched_cm = np.sum(all_confusion_mats, axis=0)
  assert batched_cm.ndim == 3, ('Expecting confusion matrix to have shape '
                                '[batch_size, num_classes, num_classes], got '
                                f'{batched_cm.shape}.')
  # ... then collapse the batch dimension to get a single confusion matrix.
  confusion_matrix = np.sum(batched_cm, axis=0)

  mean_iou, iou_per_class = model_utils.mean_iou(confusion_matrix)
  results = {'mean_iou': float(mean_iou)}
  for class_index, class_iou in enumerate(iou_per_class):
    # Zero-padded two-digit class index, e.g. 'iou_per_class/03'.
    key = f'iou_per_class/{class_index:02d}'
    if 'class_names' in dataset_metadata:
      class_name = dataset_metadata['class_names'][class_index]
      key = f"{key}_{class_name}"
    results[key] = float(class_iou)
  return results