Spaces:
Paused
Paused
# Copyright 2022 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluation library for frame interpolation."""
from typing import Dict, List, Mapping, Text

from absl import logging
import tensorflow as tf
def _collect_tensors(tensors: tf.Tensor) -> List[tf.Tensor]:
  """Collects the tensors of the different replicas into a flat list.

  Args:
    tensors: A (possibly nested) structure of tensors, e.g. the per-replica
      values produced by a distributed step.

  Returns:
    A flat list of the leaf tensors, with composite tensors expanded into
    their components.
  """
  # Note: tf.nest.flatten returns a Python list, not a tensor, so the return
  # annotation is List[tf.Tensor] (the previous `-> tf.Tensor` was incorrect).
  return tf.nest.flatten(tensors, expand_composites=True)
def _distributed_eval_step(strategy: tf.distribute.Strategy,
                           batch: Dict[Text, tf.Tensor], model: tf.keras.Model,
                           metrics: Dict[Text, tf.keras.metrics.Metric],
                           checkpoint_step: int) -> Dict[Text, tf.Tensor]:
  """Runs one evaluation step on every replica.

  Args:
    strategy: A Tensorflow distribution strategy.
    batch: A batch of training examples.
    model: The Keras model to evaluate.
    metrics: The Keras metrics used for evaluation (a dictionary).
    checkpoint_step: The iteration number at which the checkpoint is restored.

  Returns:
    list of predictions from each replica.
  """

  def _replica_fn(
      replica_batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Evaluates the model on one replica's shard of the batch."""
    outputs = model(replica_batch, training=False)
    # Unlike standard TF metrics, these metrics take the whole batch and
    # prediction dictionaries rather than bare tensors, which lets losses and
    # metrics consume a richer set of inputs than just the predicted final
    # image.
    for m in metrics.values():
      m.update_state(replica_batch, outputs, checkpoint_step=checkpoint_step)
    return outputs

  return strategy.run(_replica_fn, args=(batch,))
def _summarize_image_tensors(combined, prefix, step):
  """Writes an image summary for every image-like tensor in `combined`."""
  for name, value in combined.items():
    # Only rank-4 tensors whose last dimension is 1 (grayscale) or
    # 3 (RGB) are treated as displayable images; everything else is skipped.
    looks_like_image = (
        isinstance(value, tf.Tensor) and len(value.shape) == 4 and
        value.shape[-1] in (1, 3))
    if looks_like_image:
      tf.summary.image(prefix + '/' + name, value, step=step)
def eval_loop(strategy: tf.distribute.Strategy,
              eval_base_folder: str,
              model: tf.keras.Model,
              metrics: Dict[str, tf.keras.metrics.Metric],
              datasets: Mapping[str, tf.data.Dataset],
              summary_writer: tf.summary.SummaryWriter,
              checkpoint_step: int):
  """Eval function that is strategy agnostic.

  Evaluates `model` on every dataset in `datasets`: updates each metric on
  every batch, writes image summaries for the first few batches, and finally
  writes one scalar summary per metric per dataset.

  Args:
    strategy: A Tensorflow distributed strategy.
    eval_base_folder: A path to where the summaries event files and
      checkpoints will be saved.
    model: A function that returns the model.
    metrics: A function that returns the metrics dictionary.
    datasets: A dict of tf.data.Dataset to evaluate on.
    summary_writer: Eval summary writer.
    checkpoint_step: The number of iterations completed.
  """
  logging.info('Saving eval summaries to: %s...', eval_base_folder)
  # Every tf.summary.* call below goes through this default writer.
  summary_writer.set_as_default()
  for dataset_name, dataset in datasets.items():
    # Start each dataset from fresh metric accumulators.
    for metric in metrics.values():
      metric.reset_states()
    logging.info('Loading %s testing data ...', dataset_name)
    dataset = strategy.experimental_distribute_dataset(dataset)

    logging.info('Evaluating %s ...', dataset_name)
    batch_idx = 0
    # Only the first `max_batches_to_summarize` batches get image summaries.
    max_batches_to_summarize = 10
    for batch in dataset:
      predictions = _distributed_eval_step(strategy, batch, model, metrics,
                                           checkpoint_step)
      # Clip interpolator output to [0,1]. Clipping is done only
      # on the eval loop to get better metrics, but not on the training loop
      # so gradients are not killed.
      if strategy.num_replicas_in_sync > 1:
        # Concatenate the per-replica 'image' predictions into one global
        # batch. NOTE(review): only the 'image' entry survives this merge;
        # any other entries of `predictions` are dropped here, and the
        # per-replica `batch` values are skipped by the isinstance check in
        # _summarize_image_tensors below.
        predictions = {
            'image': tf.concat(predictions['image'].values, axis=0)
        }
      predictions['image'] = tf.clip_by_value(predictions['image'], 0., 1.)

      if batch_idx % 10 == 0:
        logging.info('Evaluating batch %s', batch_idx)
      # Incremented before the summary check, so summarized batches are
      # labeled eval_1 .. eval_9.
      batch_idx = batch_idx + 1
      if batch_idx < max_batches_to_summarize:
        # Loop through the global batch:
        prefix = f'{dataset_name}/eval_{batch_idx}'
        # Find all tensors that look like images, and summarize:
        combined = {**batch, **predictions}
        _summarize_image_tensors(combined, prefix, step=checkpoint_step)
      elif batch_idx == max_batches_to_summarize:
        # Flush once right after the last summarized batch so image summaries
        # become visible without waiting for the end of evaluation.
        tf.summary.flush()

    # Write final metric values for this dataset, then reset so the next
    # dataset starts from zero.
    for name, metric in metrics.items():
      tf.summary.scalar(
          f'{dataset_name}/{name}', metric.result(), step=checkpoint_step)
      tf.summary.flush()
      logging.info('Step {:2}, {} {}'.format(checkpoint_step,
                                             f'{dataset_name}/{name}',
                                             metric.result().numpy()))
      metric.reset_states()