| |
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| """A library for instantiating frame interpolation evaluation metrics.""" |
|
|
| from typing import Callable, Dict, Text |
|
|
| from ..losses import losses |
| import tensorflow as tf |
|
|
|
|
class TrainLossMetric(tf.keras.metrics.Metric):
  """Computes the training loss for our example and prediction format.

  The purpose of this is to ensure that we always include a loss that is
  exactly like the training loss in the evaluation, in order to detect possible
  overfitting. The reported result is the plain (unweighted) mean of the total
  loss computed in each `update_state` call.
  """

  def __init__(self, name='eval_loss', **kwargs):
    super(TrainLossMetric, self).__init__(name=name, **kwargs)
    # Running sum of per-call total losses, and number of update_state calls.
    self.acc = self.add_weight(name='train_metric_acc', initializer='zeros')
    self.count = self.add_weight(name='train_metric_count', initializer='zeros')

  def update_state(self,
                   batch,
                   predictions,
                   sample_weight=None,
                   checkpoint_step=0):
    """Accumulates the weighted sum of all configured training losses.

    Args:
      batch: The example batch, passed unchanged to every loss function.
      predictions: The model predictions, passed unchanged to every loss
        function.
      sample_weight: Unused; accepted for Keras `Metric` API compatibility.
      checkpoint_step: Step passed to each loss weight schedule.
    """
    # The loss set is looked up on every call rather than cached in __init__.
    # NOTE(review): tf.add_n raises if no training losses are configured.
    loss_functions = losses.training_losses()
    loss = tf.add_n([
        loss_value(batch, predictions) * loss_weight(checkpoint_step)
        for loss_value, loss_weight in loss_functions.values()
    ])
    self.acc.assign_add(loss)
    self.count.assign_add(1)

  def result(self):
    """Returns the mean loss per call, or 0 if nothing was accumulated yet."""
    # divide_no_nan avoids reporting NaN (0 / 0) before the first update.
    return tf.math.divide_no_nan(self.acc, self.count)

  def reset_state(self):
    """Resets the accumulated loss and the call count to zero."""
    self.acc.assign(0)
    self.count.assign(0)

  def reset_states(self):
    """Backward-compatible alias of `reset_state` for older Keras callers."""
    self.reset_state()
|
|
|
|
class L1Metric(tf.keras.metrics.Metric):
  """Computes L1 over our training example and prediction format.

  The purpose of this is to ensure that we have at least one metric that is
  comparable across all eval sessions, which lets us quickly compare models
  against each other. The reported result is the plain (unweighted) mean of
  `losses.l1_loss` over all `update_state` calls.

  NOTE(review): the default `name` ('eval_loss') is the same as
  `TrainLossMetric`'s default; pass an explicit name to tell them apart.
  """

  def __init__(self, name='eval_loss', **kwargs):
    super(L1Metric, self).__init__(name=name, **kwargs)
    # Running sum of per-call L1 losses, and number of update_state calls.
    self.acc = self.add_weight(name='l1_metric_acc', initializer='zeros')
    self.count = self.add_weight(name='l1_metric_count', initializer='zeros')

  def update_state(self, batch, prediction, sample_weight=None,
                   checkpoint_step=0):
    """Accumulates the L1 loss between `batch` and `prediction`.

    Args:
      batch: The example batch, passed unchanged to `losses.l1_loss`.
      prediction: The model predictions, passed unchanged to `losses.l1_loss`.
      sample_weight: Unused; accepted for Keras `Metric` API compatibility.
      checkpoint_step: Unused; accepted for signature parity with the other
        metrics in this module.
    """
    self.acc.assign_add(losses.l1_loss(batch, prediction))
    self.count.assign_add(1)

  def result(self):
    """Returns the mean L1 loss per call, or 0 if nothing was accumulated."""
    # divide_no_nan avoids reporting NaN (0 / 0) before the first update.
    return tf.math.divide_no_nan(self.acc, self.count)

  def reset_state(self):
    """Resets the accumulated loss and the call count to zero."""
    self.acc.assign(0)
    self.count.assign(0)

  def reset_states(self):
    """Backward-compatible alias of `reset_state` for older Keras callers."""
    self.reset_state()
|
|
|
|
class GenericLossMetric(tf.keras.metrics.Metric):
  """Metric based on any loss function and its weight schedule.

  The reported result is the plain (unweighted) mean of the weighted loss
  computed in each `update_state` call.
  """

  def __init__(self, name: str, loss: Callable[..., tf.Tensor],
               weight: Callable[..., tf.Tensor], **kwargs):
    """Initializes a metric based on a loss function and a weight schedule.

    Args:
      name: The name of the metric.
      loss: The callable loss; it is invoked as `loss(batch, predictions)` and
        should return the loss value for that pair.
      weight: The callable weight scheduling function; it is invoked with the
        checkpoint step and its result multiplies the loss.
      **kwargs: Any additional keyword arguments, forwarded to the Keras
        `Metric` constructor.
    """
    super(GenericLossMetric, self).__init__(name=name, **kwargs)
    # Running sum of per-call weighted losses, and number of update calls.
    self.acc = self.add_weight(name='loss_metric_acc', initializer='zeros')
    self.count = self.add_weight(name='loss_metric_count', initializer='zeros')
    self.loss = loss
    self.weight = weight

  def update_state(self,
                   batch,
                   predictions,
                   sample_weight=None,
                   checkpoint_step=0):
    """Accumulates `loss(batch, predictions) * weight(checkpoint_step)`.

    Args:
      batch: The example batch, passed unchanged to the loss.
      predictions: The model predictions, passed unchanged to the loss.
      sample_weight: Unused; accepted for Keras `Metric` API compatibility.
      checkpoint_step: Step passed to the weight schedule.
    """
    self.acc.assign_add(
        self.loss(batch, predictions) * self.weight(checkpoint_step))
    self.count.assign_add(1)

  def result(self):
    """Returns the mean weighted loss per call, or 0 if nothing accumulated."""
    # divide_no_nan avoids reporting NaN (0 / 0) before the first update.
    return tf.math.divide_no_nan(self.acc, self.count)

  def reset_state(self):
    """Resets the accumulated loss and the call count to zero."""
    self.acc.assign(0)
    self.count.assign(0)

  def reset_states(self):
    """Backward-compatible alias of `reset_state` for older Keras callers."""
    self.reset_state()
|
|
|
|
def create_metrics_fn() -> Dict[Text, tf.keras.metrics.Metric]:
  """Creates the evaluation metrics.

  The L1 metric and the total training loss are always included. Every entry
  returned by `losses.test_losses()` (configured via gin) is then added as a
  `GenericLossMetric`, keyed by its loss name.

  Returns:
    A dictionary from metric name to Keras Metric object.
  """
  metrics: Dict[Text, tf.keras.metrics.Metric] = {
      'l1': L1Metric(),
      'training_loss': TrainLossMetric(),
  }
  # Test losses are inserted after the defaults, so a test loss named 'l1' or
  # 'training_loss' replaces the corresponding default metric.
  metrics.update({
      loss_name: GenericLossMetric(
          name=loss_name, loss=loss_value, weight=loss_weight)
      for loss_name, (loss_value, loss_weight) in losses.test_losses().items()
  })
  return metrics
|
|