| | """Utilities for models.""" |
| |
|
| | import functools |
| | from typing import Optional, Any, Tuple, Union |
| |
|
| | import flax.linen as nn |
| | import jax |
| | import jax.numpy as jnp |
| | import numpy as np |
| |
|
| |
|
| | PyTree = Any |
| | PyModule = Any |
| | Array = Union[jnp.ndarray, np.ndarray] |
| |
|
| |
|


def psum_metric_normalizer(
    metrics: Tuple[jnp.ndarray, jnp.ndarray],
    axis_name: Union[str, Tuple[str, ...]] = 'batch'
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Applies psum over the given tuple of (metric, normalizer)."""
  psumed_metric = jax.lax.psum(jnp.sum(metrics[0]), axis_name=axis_name)
  psumed_normalizer = jax.lax.psum(jnp.sum(metrics[1]), axis_name=axis_name)
  return (psumed_metric, psumed_normalizer)
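

# Illustrative usage (a minimal sketch, not part of the original module): the
# function must run inside a mapped computation with a named axis, e.g. under
# `jax.pmap`; the values below are made up.
#
#   metric_fn = jax.pmap(
#       lambda m, n: psum_metric_normalizer((m, n)), axis_name='batch')
#   per_device_loss = jnp.ones((jax.local_device_count(), 4))
#   per_device_norm = jnp.ones((jax.local_device_count(), 4))
#   summed_loss, summed_norm = metric_fn(per_device_loss, per_device_norm)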


def num_examples(logits: jnp.ndarray,
                 one_hot_targets: jnp.ndarray,
                 weights: Optional[jnp.ndarray] = None
                 ) -> Union[jnp.ndarray, int]:
  """Returns the number of examples in the batch, or the sum of `weights`."""
  del logits
  if weights is None:
    return one_hot_targets.shape[0]
  return weights.sum()


def apply_weights(output: jnp.ndarray, weights: jnp.ndarray) -> jnp.ndarray:
  """Applies the given weights of the minibatch inputs to the outputs.

  Note that weights can be per example (i.e. of shape `[batch,]`) or per
  pixel/token (i.e. of shape `[batch, height, width]` or `[batch, len]`), so
  they need to be broadcast to the output shape.

  Args:
    output: Computed output, which can be the loss, the correctly classified
      examples, etc.
    weights: Weights of the inputs in the batch, of shape [batch, ...].

  Returns:
    Weighted output.
  """
  if output.ndim < weights.ndim:
    raise ValueError('Output rank should be higher than or equal to weights '
                     'rank.')
  desired_weights_shape = weights.shape + (1,) * (output.ndim - weights.ndim)
  weights = jax.lax.broadcast_in_dim(
      weights,
      shape=desired_weights_shape,
      broadcast_dimensions=tuple(range(weights.ndim)))

  return output * weights
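

# Illustrative usage (a minimal sketch; the values are made up): zero out the
# per-token loss of a padded example.
#
#   per_token_loss = jnp.ones((2, 3))         # [batch, len]
#   example_weights = jnp.array([1.0, 0.0])   # second example is padding
#   apply_weights(per_token_loss, example_weights)
#   # -> [[1., 1., 1.], [0., 0., 0.]]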


def weighted_correctly_classified(
    logits: jnp.ndarray,
    one_hot_targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Computes weighted number of correctly classified over the given batch.

  This computes the weighted number of correctly classified examples/pixels in
  a single, potentially padded minibatch. If the minibatch/inputs are padded
  (i.e., they contain null examples/pad pixels), it is assumed that weights is
  a binary mask where 0 indicates that the example/pixel is null/padded. We
  assume the trainer will aggregate and divide by the number of samples.

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    one_hot_targets: One-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of one_hot_targets - 1).

  Returns:
    The number of correctly classified examples in the given batch.
  """
  if logits.ndim != one_hot_targets.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s one_hot_targets' %
        (str(logits.shape), str(one_hot_targets.shape)))
  preds = jnp.argmax(logits, axis=-1)
  targets = jnp.argmax(one_hot_targets, axis=-1)
  correct = jnp.equal(preds, targets)

  if weights is not None:
    correct = apply_weights(correct, weights)

  return correct.astype(jnp.int32)
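

# Illustrative usage (a minimal sketch; the values are made up):
#
#   logits = jnp.array([[2., 1., 0.],
#                       [0., 1., 2.]])
#   one_hot = jax.nn.one_hot(jnp.array([0, 1]), 3)
#   weighted_correctly_classified(logits, one_hot)
#   # -> [1, 0]: the first prediction (class 0) is correct; the second
#   # (class 2 vs. target class 1) is not.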


def weighted_top_one_correctly_classified(
    logits: jnp.ndarray,
    multi_hot_targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Computes weighted number of correctly classified, given top-1 class.

  This computes the weighted number of correctly classified examples/pixels in
  a single, potentially padded minibatch, given the top-one prediction. If the
  minibatch/inputs are padded (i.e., they contain null examples/pad pixels), it
  is assumed that weights is a binary mask where 0 indicates that the
  example/pixel is null/padded. We assume the trainer will aggregate and
  divide by the number of samples.

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    multi_hot_targets: Multi-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of multi_hot_targets
      - 1).

  Returns:
    The number of correctly classified examples in the given batch, given the
    top-one prediction.
  """
  if logits.ndim != multi_hot_targets.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s multi_hot_targets' %
        (str(logits.shape), str(multi_hot_targets.shape)))

  top1_idx = jnp.argmax(logits, axis=-1)[..., None]
  top1_correct = jnp.take_along_axis(multi_hot_targets, top1_idx, axis=-1)
  if weights is not None:
    top1_correct = apply_weights(top1_correct, weights)

  return top1_correct


def weighted_topk_correctly_classified(logits: jnp.ndarray,
                                       multi_hot_target: jnp.ndarray,
                                       weights: Optional[jnp.ndarray] = None,
                                       k: int = 5) -> jnp.ndarray:
  """Computes weighted number of correctly classified given top-k predictions.

  This computes the weighted number of correctly classified examples/pixels in
  a single, potentially padded minibatch, given the top-k prediction. In the
  multi-hot target case, the sample is considered correct when any of the
  top-k predictions matches any of the multi-hot targets. If the
  minibatch/inputs are padded (i.e., they contain null examples/pad pixels),
  it is assumed that weights is a binary mask where 0 indicates that the
  example/pixel is null/padded. We assume the trainer will aggregate and
  divide by the number of samples.

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    multi_hot_target: Multi-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of multi_hot_target
      - 1).
    k: Number of top predictions to consider.

  Returns:
    The number of correctly classified examples in the given batch, given the
    top-k prediction.
  """
  if logits.ndim != multi_hot_target.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s multi_hot_target' %
        (str(logits.shape), str(multi_hot_target.shape)))
  if k <= 0 or k > logits.shape[-1]:
    raise ValueError('Incorrect k. k must be in [1, %s]' %
                     str(logits.shape[-1]))

  topk_pred = jax.lax.top_k(logits, k)[1]

  num_classes = logits.shape[-1]
  multi_hot_pred = jnp.sum(
      jax.nn.one_hot(topk_pred, num_classes=num_classes), axis=-2)
  correct = jnp.any(
      multi_hot_pred * multi_hot_target, axis=-1, keepdims=True
  ).astype(jnp.float32)

  if weights is not None:
    correct = apply_weights(correct, weights)

  return correct.astype(jnp.int32)
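

# Illustrative usage (a minimal sketch; the values are made up):
#
#   logits = jnp.array([[0.1, 0.9, 0.5]])
#   target = jnp.array([[1., 0., 0.]])  # only class 0 is a true label
#   weighted_topk_correctly_classified(logits, target, k=2)  # -> [[0]]
#   weighted_topk_correctly_classified(logits, target, k=3)  # -> [[1]]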


def weighted_precision_at_k(logits: jnp.ndarray,
                            multi_hot_target: jnp.ndarray,
                            weights: Optional[jnp.ndarray] = None,
                            k: int = 5) -> jnp.ndarray:
  """Computes fraction of correct predictions among the top-k predictions.

  This computes the weighted precision-at-k (i.e. the fraction of true
  positives among the top k predicted classes) in a single, potentially padded
  minibatch. If the minibatch/inputs are padded (i.e., they contain null
  examples/pad pixels), it is assumed that weights is a binary mask where 0
  indicates that the example/pixel is null/padded. We assume the trainer will
  aggregate and divide by the number of samples.

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    multi_hot_target: Multi-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of multi_hot_target
      - 1).
    k: Number of top predictions to consider.

  Returns:
    The precision for each example in the batch, given the top-k predictions.
  """
  if logits.ndim != multi_hot_target.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s multi_hot_target' %
        (str(logits.shape), str(multi_hot_target.shape)))
  if k <= 0 or k > logits.shape[-1]:
    raise ValueError('Incorrect k. k must be in [1, %s]' %
                     str(logits.shape[-1]))

  topk_pred = jax.lax.top_k(logits, k)[1]

  num_classes = logits.shape[-1]
  multi_hot_pred = jnp.sum(
      jax.nn.one_hot(topk_pred, num_classes=num_classes), axis=-2)

  true_positive = jnp.sum(
      multi_hot_pred * multi_hot_target, axis=-1).astype(jnp.float32)
  precision = true_positive / k

  if weights is not None:
    precision = apply_weights(precision, weights)

  return precision


def weighted_recall(logits: Array, multi_hot_target: Array,
                    weights: Optional[Array] = None) -> Array:
  """Computes weighted recall given the top-k predictions.

  This computes the weighted number of correctly recalled examples/pixels in a
  single, potentially padded minibatch, given the top-k prediction. Per
  sample, k is the number of ground-truth labels in that sample. If the
  minibatch/inputs are padded (i.e., they contain null examples/pad pixels),
  it is assumed that weights is a binary mask where 0 indicates that the
  example/pixel is null/padded. We assume the trainer will aggregate and
  divide by the number of samples.

  Args:
    logits: float array; Output of model in shape [batch, ..., num_classes].
    multi_hot_target: Multi-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of multi_hot_target
      - 1).

  Returns:
    The fraction of correctly recalled labels.
  """
  if logits.ndim != multi_hot_target.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s multi_hot_target' %
        (str(logits.shape), str(multi_hot_target.shape)))

  num_classes = multi_hot_target.shape[-1]

  # Sort the classes by decreasing logit and mark, per rank, whether the
  # predicted class is one of the ground-truth labels.
  indices_top = jnp.argsort(logits, axis=-1)[..., ::-1]
  predictions_at_top = jax.nn.one_hot(indices_top, num_classes)
  correct_at_top = jnp.sum(
      predictions_at_top * jnp.expand_dims(multi_hot_target, axis=-2), axis=-1)

  # Only the top-k ranks count, where k is the number of ground-truth labels
  # of each sample.
  num_gt_labels = jnp.sum(multi_hot_target, axis=-1, keepdims=True)
  mask = (num_gt_labels > jnp.arange(num_classes)).astype(jnp.int32)

  recall = jnp.sum(correct_at_top * mask, axis=-1) / (
      jnp.sum(multi_hot_target, axis=-1) + 1e-12)

  if weights is not None:
    recall = apply_weights(recall, weights)

  return recall
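

# Illustrative usage (a minimal sketch; the values are made up): the sample
# has two ground-truth labels, so the top-2 predictions are scored.
#
#   logits = jnp.array([[0.3, 0.9, 0.5]])
#   target = jnp.array([[1., 1., 0.]])  # classes 0 and 1 are true labels
#   weighted_recall(logits, target)
#   # -> [0.5]: of the top-2 predictions (classes 1 and 2), only class 1 is
#   # a true label.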


def apply_label_smoothing(one_hot_targets: jnp.ndarray,
                          label_smoothing: Optional[float]) -> jnp.ndarray:
  """Applies label smoothing to the one-hot targets.

  Applies label smoothing such that the on-values are transformed from 1.0 to
  `1.0 - label_smoothing + label_smoothing / num_classes`, and the off-values
  are transformed from 0.0 to `label_smoothing / num_classes`.
  https://arxiv.org/abs/1512.00567

  Note that another way of performing label smoothing (which we don't use
  here) is to take `label_smoothing` mass from the on-values and distribute it
  to the off-values; in other words, transform the on-values to
  `1.0 - label_smoothing` and the off-values to
  `label_smoothing / (num_classes - 1)`.
  http://jmlr.org/papers/v20/18-789.html

  Args:
    one_hot_targets: One-hot targets for an example, a
      [batch, ..., num_classes] float array.
    label_smoothing: A scalar in [0, 1] used to smooth the labels.

  Returns:
    A float array of the same shape as `one_hot_targets` with smoothed label
    values.
  """
  on_value = 1.0 - label_smoothing
  num_classes = one_hot_targets.shape[-1]
  off_value = label_smoothing / num_classes
  one_hot_targets = one_hot_targets * on_value + off_value
  return one_hot_targets
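

# Illustrative numeric check (a minimal sketch; values made up): with 4
# classes and label_smoothing=0.1, the on-value becomes
# 1.0 - 0.1 + 0.1 / 4 = 0.925 and the off-value becomes 0.1 / 4 = 0.025.
#
#   apply_label_smoothing(jnp.array([[1., 0., 0., 0.]]), 0.1)
#   # -> [[0.925, 0.025, 0.025, 0.025]]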


def weighted_unnormalized_softmax_cross_entropy(
    logits: jnp.ndarray,
    one_hot_targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    label_smoothing: Optional[float] = None,
    label_weights: Optional[jnp.ndarray] = None,
    logits_normalized: bool = False,
    keep_label_dimension: bool = False) -> jnp.ndarray:
  """Computes weighted softmax cross entropy given logits and targets.

  This computes sum_(x,y) softmax-ce(x, y) for a single, potentially padded
  minibatch. If the minibatch is padded (that is, it contains null examples),
  it is assumed that weights is a binary mask where 0 indicates that the
  example is null.

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    one_hot_targets: One-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of one_hot_targets - 1).
    label_smoothing: Scalar to use to smooth the one-hot labels.
    label_weights: Weight per label of shape [num_classes].
    logits_normalized: If True, the logits are assumed to already be
      normalized.
    keep_label_dimension: If True, the class dimension of the output loss is
      not summed over.

  Returns:
    The softmax cross entropy of the examples in the given batch.
  """
  if logits.ndim != one_hot_targets.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s one_hot_targets' %
        (str(logits.shape), str(one_hot_targets.shape)))

  # Optionally apply label smoothing.
  if label_smoothing is not None:
    one_hot_targets = apply_label_smoothing(one_hot_targets, label_smoothing)

  # Optionally apply per-label weights.
  if label_weights is not None:
    one_hot_targets *= label_weights

  if not logits_normalized:
    logits = nn.log_softmax(logits)
  loss = -one_hot_targets * logits
  if weights is not None:
    loss = apply_weights(loss, weights)

  if not keep_label_dimension:
    loss = loss.sum(axis=-1)

  return loss
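

# Illustrative usage (a minimal sketch; the values are made up): with uniform
# logits over 3 classes the per-example loss is log(3) ~= 1.0986, and the
# padded second example is zeroed out by its weight.
#
#   logits = jnp.zeros((2, 3))
#   one_hot = jax.nn.one_hot(jnp.array([0, 1]), 3)
#   weights = jnp.array([1., 0.])
#   weighted_unnormalized_softmax_cross_entropy(logits, one_hot, weights)
#   # -> [1.0986, 0.]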


def weighted_unnormalized_sigmoid_cross_entropy(
    logits: jnp.ndarray,
    multi_hot_targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    label_weights: Optional[jnp.ndarray] = None,
    label_smoothing: Optional[float] = None,
    logits_normalized: bool = False) -> jnp.ndarray:
  """Computes weighted sigmoid cross entropy given logits and targets.

  This is also called binary cross-entropy loss; it measures the probability
  error in discrete classification tasks in which each class is independent
  and not mutually exclusive.
  This computes sum_(x,y) sigmoid-ce(x, y) for a single, potentially padded
  minibatch. If the minibatch is padded (that is, it contains null examples),
  it is assumed that weights is a binary mask where 0 indicates that the
  example is null.

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    multi_hot_targets: Multi-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of multi_hot_targets
      - 1). This is the weight to apply to the loss computed for each example
      in the batch. Can be used to ignore padded examples in the batch.
    label_weights: None or array of shape broadcastable to the shape of
      logits. Typically this would be [num_classes] and is the weight to apply
      to each label.
    label_smoothing: Scalar to use to smooth the one-hot labels.
    logits_normalized: If True, the logits are assumed to be log probs.

  Returns:
    The sigmoid cross entropy of the examples in the given batch.
  """
  if logits.ndim != multi_hot_targets.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s logits and %s multi_hot_targets' %
        (str(logits.shape), str(multi_hot_targets.shape)))

  # Optionally apply label smoothing.
  if label_smoothing is not None:
    multi_hot_targets = apply_label_smoothing(multi_hot_targets,
                                              label_smoothing)

  if logits_normalized:
    log_p, prob = logits, jnp.exp(logits)
    log_not_p = jnp.log((1 + 1e-6) - prob)
  else:
    log_p, log_not_p = jax.nn.log_sigmoid(logits), jax.nn.log_sigmoid(-logits)

  loss = -(multi_hot_targets * log_p +
           (1. - multi_hot_targets) * log_not_p)

  if label_weights is not None:
    loss = loss * label_weights

  if weights is not None:
    loss = apply_weights(loss, weights)

  return loss


def weighted_softmax_cross_entropy(
    logits: jnp.ndarray,
    one_hot_targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    label_smoothing: Optional[float] = None,
    label_weights: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Same as weighted_unnormalized_softmax_cross_entropy, but takes the mean.

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    one_hot_targets: One-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of one_hot_targets - 1).
    label_smoothing: Float scalar to use to smooth the one-hot labels.
    label_weights: Weight per label of shape [num_classes].

  Returns:
    The mean cross entropy of the examples in the given batch as a scalar.
  """
  if weights is not None:
    normalization = weights.sum()
  else:
    normalization = np.prod(one_hot_targets.shape[:-1])

  unnormalized_softmax_ce = weighted_unnormalized_softmax_cross_entropy(
      logits, one_hot_targets, weights, label_smoothing, label_weights)
  return jnp.sum(unnormalized_softmax_ce) / (normalization + 1e-8)


def weighted_sigmoid_cross_entropy(
    logits: jnp.ndarray,
    multi_hot_targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    label_weights: Optional[jnp.ndarray] = None,
    label_smoothing: Optional[float] = None) -> jnp.ndarray:
  """Same as weighted_unnormalized_sigmoid_cross_entropy, but takes the mean.

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    multi_hot_targets: Multi-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of multi_hot_targets
      - 1).
    label_weights: None or array of shape broadcastable to the shape of
      logits. Typically this would be [num_classes] and is the weight to apply
      to each label.
    label_smoothing: Scalar to use to smooth the one-hot labels.

  Returns:
    The mean cross entropy of the examples in the given batch as a scalar.
  """
  if weights is not None:
    normalization = weights.sum()
  else:
    normalization = np.prod(multi_hot_targets.shape[:-1])

  unnormalized_sigmoid_ce = weighted_unnormalized_sigmoid_cross_entropy(
      logits,
      multi_hot_targets,
      weights=weights,
      label_weights=label_weights,
      label_smoothing=label_smoothing)
  return jnp.sum(unnormalized_sigmoid_ce) / (normalization + 1e-8)


def l2_regularization(params: PyTree):
  """Calculates the L2 loss (squared L2 norm), given parameters of the model.

  Args:
    params: Parameters of the model.

  Returns:
    The sum of squares of all parameter leaves with rank > 1 (i.e. biases and
    other rank-1 parameters such as normalization scales are excluded).
  """
  weight_penalty_params = jax.tree_util.tree_leaves(params)
  return sum([jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
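

# Illustrative usage (a minimal sketch; the parameter pytree is made up): only
# leaves with rank > 1 (e.g. dense kernels) contribute; biases are skipped.
#
#   params = {'dense': {'kernel': jnp.ones((2, 2)), 'bias': jnp.ones((2,))}}
#   l2_regularization(params)  # -> 4.0 (sum of squares of the kernel only)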


def weighted_l1_loss(x: jnp.ndarray,
                     y: jnp.ndarray,
                     weights: Optional[jnp.ndarray] = None,
                     reduction: Optional[str] = None) -> jnp.ndarray:
  """L1 loss with optional reduction specified.

  Args:
    x: Input array of any shape.
    y: Input array of shape broadcastable to that of x.
    weights: Weights to apply to the loss.
    reduction: Type of reduction, which is from [None, 'mean'].

  Returns:
    reduction(jnp.abs(x - y)). 'mean' reduction takes the global mean. To use
    customized normalization use 'none' reduction and scale loss in the
    caller.
  """
  abs_diff = jnp.abs(x - y)
  if weights is not None:
    abs_diff = apply_weights(abs_diff, weights)
  if not reduction:
    return abs_diff
  elif reduction == 'mean':
    return abs_diff.mean()
  else:
    raise ValueError(f'Unknown reduction: {reduction}')


def weighted_box_l1_loss(
    pred: jnp.ndarray,
    tgt: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    reduction: Optional[str] = None,
    tight: bool = True,
) -> jnp.ndarray:
  """L1 loss for bounding boxes with optional reduction specified.

  Args:
    pred: Prediction boxes of shape (..., 4), where the last dimension has
      form (x_min, y_min, x_max, y_max).
    tgt: Target boxes of shape (..., 4), where the last dimension has form
      (x_min, y_min, x_max, y_max).
    weights: Weights to apply to the loss.
    reduction: Type of reduction, which is from [None, 'mean'].
    tight: If True, returns the vanilla L1 loss on the bounding box
      coordinates. If False, returns loose bounding-box L1 loss, where
      prediction edges only generate loss when they stretch outside the target
      box, but not when they are within it.

  Returns:
    reduction(jnp.abs(pred - tgt)). 'mean' reduction takes the global mean. To
    use customized normalization use 'none' reduction and scale loss in the
    caller.
  """
  if pred.shape[-1] != 4:
    raise ValueError(
        f'The last dimension of the prediction boxes must be 4.'
        f' Got shape {pred.shape}.'
    )
  if tgt.shape[-1] != 4:
    raise ValueError(
        f'The last dimension of the target boxes must be 4.'
        f' Got shape {tgt.shape}.'
    )
  if tight:
    abs_diff = jnp.abs(pred - tgt)
  else:
    # Loose loss: only penalize (x_min, y_min) when it is smaller than the
    # target and (x_max, y_max) when it is larger, i.e. when the predicted
    # edge lies outside the target box.
    xy1, xy2 = jnp.split(pred - tgt, 2, axis=-1)
    xy1 = jnp.minimum(xy1, 0.)
    xy2 = jnp.maximum(xy2, 0.)
    abs_diff = jnp.abs(jnp.concatenate([xy1, xy2], axis=-1))
  if weights is not None:
    abs_diff = apply_weights(abs_diff, weights)
  if not reduction:
    return abs_diff
  elif reduction == 'mean':
    return abs_diff.mean()
  else:
    raise ValueError(f'Unknown reduction: {reduction}')
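

# Illustrative check of the `tight` flag (a minimal sketch; boxes made up):
#
#   tgt = jnp.array([[0.2, 0.2, 0.8, 0.8]])
#   inside = jnp.array([[0.3, 0.3, 0.7, 0.7]])   # fully inside the target
#   outside = jnp.array([[0.1, 0.1, 0.9, 0.9]])  # stretches outside
#   weighted_box_l1_loss(inside, tgt, tight=False).sum()   # -> 0.0
#   weighted_box_l1_loss(outside, tgt, tight=False).sum()  # -> 0.4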


def weighted_squared_error(
    predictions: jnp.ndarray,
    targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    axis: Optional[Union[int, Tuple[int, ...]]] = None) -> jnp.ndarray:
  """Computes weighted squared error given predictions and targets.

  This computes the squared_error of examples in a single, potentially padded
  minibatch. If the minibatch is padded (that is, it contains null examples),
  it is assumed that weights is a binary mask where 0 indicates that the
  example is null.

  Args:
    predictions: Output of model in shape [batch, ..., n_features].
    targets: Array of shape [batch, ..., n_features].
    weights: None or array of shape [batch, ...]. This is the weight to apply
      to the loss computed for each example in the batch. Can be used to
      ignore padded examples in the batch.
    axis: The axis (or axes) to compute the loss over. If not specified, all
      dimensions besides the leading batch dimension are used.

  Returns:
    The squared error for each example in the given batch, summed over `axis`.
    The output shape depends on axis.
  """
  if predictions.ndim != targets.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s predictions and %s targets' %
        (str(predictions.shape), str(targets.shape)))
  if axis is None:
    # Sum over all dimensions except the leading batch dimension.
    axis = tuple(range(1, predictions.ndim))

  error = targets - predictions
  loss = jnp.square(error)
  loss = jnp.sum(loss, axis=axis)
  if weights is not None:
    loss = apply_weights(loss, weights)
  return loss


def weighted_mean_squared_error(
    predictions: jnp.ndarray,
    targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    axis: Optional[Union[int, Tuple[int, ...]]] = None) -> jnp.ndarray:
  """Weighted mean of weighted_squared_error.

  Args:
    predictions: Output of model in shape [batch, ..., num_features].
    targets: Targets of shape [batch, ..., num_features].
    weights: None or array of shape [batch, ...]. This is the weight to apply
      to the loss computed for each example in the batch. Can be used to
      ignore padded examples in the batch.
    axis: The axis (or axes) to compute the loss over. If not specified, all
      dimensions besides the leading batch dimension are used.

  Returns:
    The averaged mean squared error of all the examples in the given batch as
    a scalar.
  """
  unnormalized_mse = weighted_squared_error(
      predictions=predictions, targets=targets, weights=weights, axis=axis)

  if weights is not None:
    # Normalize by the total weight, broadcast to the shape of the loss.
    broadcasted_shape = weights.shape + (1,) * (
        unnormalized_mse.ndim - weights.ndim)
    broadcasted_weights = jax.lax.broadcast_in_dim(
        weights,
        shape=broadcasted_shape,
        broadcast_dimensions=tuple(range(weights.ndim)))
    normalization = jnp.sum(broadcasted_weights *
                            jnp.ones(unnormalized_mse.shape))
  else:
    # Normalize by the number of loss elements.
    normalization = unnormalized_mse.size
  return jnp.sum(unnormalized_mse) / (normalization + 1e-8)


def weighted_absolute_error(
    predictions: jnp.ndarray,
    targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    axis: Optional[Union[int, Tuple[int, ...]]] = None) -> jnp.ndarray:
  """Computes weighted absolute error given predictions and targets.

  This computes the absolute_error of examples in a single, potentially padded
  minibatch. If the minibatch is padded (that is, it contains null examples),
  it is assumed that weights is a binary mask where 0 indicates that the
  example is null.

  Args:
    predictions: Output of model in shape [batch, ..., n_features].
    targets: Array of shape [batch, ..., n_features].
    weights: None or array of shape [batch, ...]. This is the weight to apply
      to the loss computed for each example in the batch. Can be used to
      ignore padded examples in the batch.
    axis: The axis (or axes) to compute the loss over. If not specified, all
      dimensions besides the leading batch dimension are used.

  Returns:
    The absolute error for each example in the given batch, summed over
    `axis`. The output shape depends on axis.
  """
  if predictions.ndim != targets.ndim:
    raise ValueError(
        'Incorrect shapes. Got shape %s predictions and %s targets' %
        (str(predictions.shape), str(targets.shape)))
  if axis is None:
    # Sum over all dimensions except the leading batch dimension.
    axis = tuple(range(1, predictions.ndim))

  error = targets - predictions
  loss = jnp.absolute(error)
  loss = jnp.sum(loss, axis=axis)
  if weights is not None:
    loss = apply_weights(loss, weights)
  return loss


def weighted_mean_absolute_error(
    predictions: jnp.ndarray,
    targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    axis: Optional[Union[int, Tuple[int, ...]]] = None) -> jnp.ndarray:
  """Weighted mean of weighted_absolute_error.

  Args:
    predictions: Output of model in shape [batch, ..., num_features].
    targets: Targets of shape [batch, ..., num_features].
    weights: None or array of shape [batch, ...]. This is the weight to apply
      to the loss computed for each example in the batch. Can be used to
      ignore padded examples in the batch.
    axis: The axis (or axes) to compute the loss over. If not specified, all
      dimensions besides the leading batch dimension are used.

  Returns:
    The averaged mean absolute error of all the examples in the given batch
    as a scalar.
  """
  unnormalized_mae = weighted_absolute_error(
      predictions=predictions, targets=targets, weights=weights, axis=axis)

  if weights is not None:
    # Normalize by the total weight.
    normalization = weights.sum()
  else:
    # Normalize by the batch size.
    normalization = unnormalized_mae.shape[0]
  return jnp.sum(unnormalized_mae) / (normalization + 1e-8)


def focal_softmax_cross_entropy(
    logits: jnp.ndarray,
    one_hot_targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    label_smoothing: Optional[float] = None,
    label_weights: Optional[jnp.ndarray] = None,
    logits_normalized: bool = False,
    gamma: Optional[float] = 2.0,
    keep_label_dimension: bool = False) -> jnp.ndarray:
  """Computes focal softmax cross-entropy given logits and targets.

  Focal loss as defined in https://arxiv.org/abs/1708.02002. Assuming y is the
  target vector and p is the predicted probability for the class, then:

  p_t = p if y == 1 and 1-p otherwise
  Focal loss = -(1-p_t)**gamma * log(p_t)

  NOTE: this is a weighted, unnormalized computation of loss that returns the
  loss of examples in the batch. If you are using it as a loss function, you
  can use the normalized version as:
  ```
  unnormalized_loss = focal_softmax_cross_entropy(...)
  if weights is not None:
    normalization = weights.sum()
  else:
    normalization = np.prod(one_hot_targets.shape[:-1])
  loss = jnp.sum(unnormalized_loss) / (normalization + 1e-8)
  ```

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    one_hot_targets: One-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of one_hot_targets - 1).
    label_smoothing: Scalar to use to smooth the one-hot labels.
    label_weights: Weight per label of shape [num_classes].
    logits_normalized: If True, the logits are assumed to be log probs.
    gamma: Modulating factor of the focal loss.
    keep_label_dimension: If True, the class dimension of the output loss is
      not summed over.

  Returns:
    The loss of the examples in the given batch.
  """
  loss = weighted_unnormalized_softmax_cross_entropy(
      logits, one_hot_targets, weights=None, label_smoothing=label_smoothing,
      label_weights=label_weights, logits_normalized=logits_normalized,
      keep_label_dimension=True)
  prob = jnp.exp(logits) if logits_normalized else jax.nn.softmax(logits)
  prob = (prob * one_hot_targets).sum(axis=-1, keepdims=True)
  loss *= (1. - prob)**gamma
  if weights is not None:
    loss = apply_weights(loss, weights)

  if not keep_label_dimension:
    loss = loss.sum(axis=-1)

  return loss


def focal_sigmoid_cross_entropy(
    logits: jnp.ndarray,
    multi_hot_targets: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    label_smoothing: Optional[float] = None,
    label_weights: Optional[jnp.ndarray] = None,
    logits_normalized: bool = False,
    alpha: Optional[float] = 0.5,
    gamma: Optional[float] = 2.0) -> jnp.ndarray:
  """Computes focal sigmoid cross-entropy given logits and targets.

  Focal loss as defined in https://arxiv.org/abs/1708.02002. Assuming y is the
  target vector and p is the predicted probability for the class, then:

  p_t = p if y == 1 and 1-p otherwise
  alpha_t = alpha if y == 1 and 1-alpha otherwise

  Focal loss = -alpha_t * (1-p_t)**gamma * log(p_t)

  NOTE: this is a weighted, unnormalized computation of loss that returns the
  loss of examples in the batch. If you are using it as a loss function, you
  can use the normalized version as:
  ```
  unnormalized_loss = focal_sigmoid_cross_entropy(...)
  if weights is not None:
    normalization = weights.sum()
  else:
    normalization = np.prod(multi_hot_targets.shape[:-1])
  loss = jnp.sum(unnormalized_loss) / (normalization + 1e-8)
  ```

  Args:
    logits: Output of model in shape [batch, ..., num_classes].
    multi_hot_targets: Multi-hot vector of shape [batch, ..., num_classes].
    weights: None or array of shape [batch, ...] (rank of multi_hot_targets
      - 1).
    label_smoothing: Scalar to use to smooth the one-hot labels.
    label_weights: Weight per label of shape [num_classes].
    logits_normalized: If True, the logits are assumed to be log probs.
    alpha: Balancing factor of the focal loss.
    gamma: Modulating factor of the focal loss.

  Returns:
    The loss of the examples in the given batch.
  """
  # Optionally apply label smoothing.
  if label_smoothing is not None:
    multi_hot_targets = apply_label_smoothing(multi_hot_targets,
                                              label_smoothing)
  if logits_normalized:
    log_p, prob = logits, jnp.exp(logits)
    log_not_p = jnp.log((1 + 1e-6) - prob)
  else:
    log_p, log_not_p = jax.nn.log_sigmoid(logits), jax.nn.log_sigmoid(-logits)

  loss = -(multi_hot_targets * log_p + (1. - multi_hot_targets) * log_not_p)

  p_t = jnp.exp(-loss)
  loss *= (1 - p_t)**gamma
  loss *= alpha * multi_hot_targets + (1 - alpha) * (1 - multi_hot_targets)

  if label_weights is not None:
    loss = loss * label_weights

  if weights is not None:
    loss = apply_weights(loss, weights)
  return loss


@functools.partial(jax.vmap, in_axes=[0, 0], out_axes=0)
def simple_gather(x: jnp.ndarray, idx: jnp.ndarray) -> jnp.ndarray:
  """Gathers `x` using the indices in `idx`.

  `output[i] = x[i, idx[i]]`. This simple gather operation assumes that the
  first dimension is the batch dimension. The indices index into the second
  dimension. The rest of the dimensions are copied as-is from `x` into the
  output. Note that the implementation below only handles a single element in
  the batch; `jax.vmap` extends it to the batch dimension.

  Args:
    x: Inputs of shape [bs, n, d].
    idx: An array of shape [bs, m] and dtype jnp.int32 or int64 that specifies
      indexes we want to gather from x.

  Returns:
    Gathered output of shape [bs, m, d].
  """
  return x[idx]
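

# Illustrative usage (a minimal sketch; the values are made up):
#
#   x = jnp.array([[[0., 0.], [1., 1.], [2., 2.]]])  # [bs=1, n=3, d=2]
#   idx = jnp.array([[2, 0]])                        # [bs=1, m=2]
#   simple_gather(x, idx)
#   # -> [[[2., 2.], [0., 0.]]]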


def confusion_matrix(y_true: Array,
                     y_pred: Array,
                     num_classes: int,
                     weights: Optional[Array] = None,
                     np_backbone: PyModule = jnp) -> Array:
  """Computes the confusion matrix between y_true and y_pred.

  Args:
    y_true: Array of true labels.
    y_pred: Array of predicted labels.
    num_classes: Number of classes.
    weights: nd-array; Weight of each datapoint (e.g. for masking).
    np_backbone: Numpy module: Either the regular numpy package or jax.numpy.

  Returns:
    A [num_classes, num_classes] confusion matrix, where rows correspond to
    true labels and columns to predicted labels.
  """
  assert y_true.shape == y_pred.shape
  if weights is None:
    weights = np_backbone.ones_like(y_true)
  else:
    assert y_true.shape == weights.shape

  # If all weights are zero, fall back to uniform weights to avoid NaNs in the
  # histogram; the resulting matrix is zeroed out again below.
  weights_all_zero = 1.0 - np_backbone.any(weights).astype(np_backbone.float32)
  weights = weights + weights_all_zero

  cm, *_ = np_backbone.histogram2d(
      y_true.ravel(),
      y_pred.ravel(),
      bins=np_backbone.arange(num_classes + 1),
      weights=weights.ravel())

  # Zero out the matrix if all weights were zero.
  cm = cm * (1.0 - weights_all_zero)
  return cm
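

# Illustrative usage (a minimal sketch; the labels are made up):
#
#   y_true = jnp.array([0, 1, 1])
#   y_pred = jnp.array([0, 1, 0])
#   confusion_matrix(y_true, y_pred, num_classes=2)
#   # -> [[1., 0.],
#   #     [1., 1.]]  (row = true label, column = predicted label)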


def mean_iou(cm: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  """Computes the mean intersection-over-union, given a confusion matrix.

  Args:
    cm: array_like; [num_classes, num_classes] confusion matrix.

  Returns:
    A tuple of the scalar mean intersection-over-union score and the per-class
    intersection-over-union scores.
  """
  # Per-class totals of predictions, targets, and true positives.
  sum_over_row = np.sum(cm, axis=0)
  sum_over_col = np.sum(cm, axis=1)
  true_positives = np.diag(cm)

  # The union of each class is the total of predictions plus the total of
  # targets, minus the intersection (true positives).
  denominator = sum_over_row + sum_over_col - true_positives

  # Classes that appear in neither predictions nor targets have a zero
  # denominator and yield NaN, which `np.nanmean` ignores in the mean and
  # `np.nan_to_num` maps to zero in the per-class output.
  iou_per_class = true_positives / denominator
  return (np.nan_to_num(np.nanmean(iou_per_class)),
          np.nan_to_num(iou_per_class))


def dice_loss(inputs: jnp.ndarray,
              targets: jnp.ndarray,
              weights: Optional[jnp.ndarray] = None,
              all_pairs: bool = False,
              eps: float = 1.0,
              interpolation: str = 'nearest') -> jnp.ndarray:
  """Computes the Dice loss given panoptic segmentation logits and targets.

  This loss is based on the Dice coefficient (F-1 score). For details, see
  https://arxiv.org/abs/2005.12872 and https://arxiv.org/pdf/1606.04797.pdf.

  Args:
    inputs: Predicted mask logits with shape [batch, num_objects, H, W].
    targets: Target masks with shape [batch, num_objects, H, W].
    weights: Array of shape [batch, ...].
    all_pairs: Whether to compute the loss for all object pairs or not.
    eps: Epsilon for numerical stability.
    interpolation: Method to use for resizing targets to the size of the
      predicted masks.

  Returns:
    If all_pairs == True, returns a [bs, n, m] pairwise matrix of dice loss.
    If all_pairs == False, returns a [bs, n] matrix of dice loss.
  """
  _, n, h, w = inputs.shape
  b, m, _, _ = targets.shape

  # Resize the target masks to the spatial resolution of the predicted masks.
  targets = jax.image.resize(
      targets, shape=[b, m, h, w], method=interpolation, antialias=True)

  # Map mask logits to probabilities.
  inputs = jax.nn.sigmoid(inputs)

  inputs = jnp.reshape(inputs, [b, n, h * w])
  targets = jnp.reshape(targets, [b, m, h * w])
  if all_pairs:
    numerator = 2 * jnp.einsum('bnp,bkp->bnk', inputs, targets)
    denominator = (jnp.sum(inputs[:, :, None, :], axis=-1) +
                   jnp.sum(targets[:, None, :, :], axis=-1))
  else:
    assert n == m
    numerator = 2 * jnp.einsum('bnp,bnp->bn', inputs, targets)
    denominator = jnp.sum(inputs + targets, axis=-1)
  loss = 1.0 - (numerator + eps) / (denominator + eps)

  if weights is not None:
    loss = apply_weights(loss, weights)

  return loss
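

# Illustrative shape check (a minimal sketch; shapes are made up):
#
#   mask_logits = jnp.zeros((2, 3, 8, 8))    # [batch, num_objects, H, W]
#   gt_masks = jnp.ones((2, 3, 16, 16))      # targets at a higher resolution
#   dice_loss(mask_logits, gt_masks).shape                  # -> (2, 3)
#   dice_loss(mask_logits, gt_masks, all_pairs=True).shape  # -> (2, 3, 3)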