Spaces:
Paused
Paused
| # Copyright 2022 Google LLC | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| r"""Training library for frame interpolation using distributed strategy.""" | |
| import functools | |
| from typing import Any, Callable, Dict, Text, Tuple | |
| from absl import logging | |
| import tensorflow as tf | |
| def _concat_tensors(tensors: tf.Tensor) -> tf.Tensor: | |
| """Concat tensors of the different replicas.""" | |
| return tf.concat(tf.nest.flatten(tensors, expand_composites=True), axis=0) | |
| def _distributed_train_step(strategy: tf.distribute.Strategy, | |
| batch: Dict[Text, tf.Tensor], model: tf.keras.Model, | |
| loss_functions: Dict[Text, | |
| Tuple[Callable[..., tf.Tensor], | |
| Callable[..., | |
| tf.Tensor]]], | |
| optimizer: tf.keras.optimizers.Optimizer, | |
| iterations: int) -> Dict[Text, Any]: | |
| """Distributed training step. | |
| Args: | |
| strategy: A Tensorflow distribution strategy. | |
| batch: A batch of training examples. | |
| model: The Keras model to train. | |
| loss_functions: The list of Keras losses used to train the model. | |
| optimizer: The Keras optimizer used to train the model. | |
| iterations: Iteration number used to sample weights to each loss. | |
| Returns: | |
| A dictionary of train step outputs. | |
| """ | |
| def _train_step(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]: | |
| """Train for one step.""" | |
| with tf.GradientTape() as tape: | |
| predictions = model(batch, training=True) | |
| losses = [] | |
| for (loss_value, loss_weight) in loss_functions.values(): | |
| losses.append(loss_value(batch, predictions) * loss_weight(iterations)) | |
| loss = tf.add_n(losses) | |
| grads = tape.gradient(loss, model.trainable_variables) | |
| optimizer.apply_gradients(zip(grads, model.trainable_variables)) | |
| # post process for visualization | |
| all_data = {'loss': loss} | |
| all_data.update(batch) | |
| all_data.update(predictions) | |
| return all_data | |
| step_outputs = strategy.run(_train_step, args=(batch,)) | |
| loss = strategy.reduce( | |
| tf.distribute.ReduceOp.MEAN, step_outputs['loss'], axis=None) | |
| x0 = _concat_tensors(step_outputs['x0']) | |
| x1 = _concat_tensors(step_outputs['x1']) | |
| y = _concat_tensors(step_outputs['y']) | |
| pred_y = _concat_tensors(step_outputs['image']) | |
| scalar_summaries = {'training_loss': loss} | |
| image_summaries = { | |
| 'x0': x0, | |
| 'x1': x1, | |
| 'y': y, | |
| 'pred_y': pred_y | |
| } | |
| extra_images = { | |
| 'importance0', 'importance1', 'x0_warped', 'x1_warped', 'fg_image', | |
| 'bg_image', 'fg_alpha', 'x1_unfiltered_warped' | |
| } | |
| for image in extra_images: | |
| if image in step_outputs: | |
| image_summaries[image] = _concat_tensors(step_outputs[image]) | |
| return { | |
| 'loss': loss, | |
| 'scalar_summaries': scalar_summaries, | |
| 'image_summaries': { | |
| f'training/{name}': value for name, value in image_summaries.items() | |
| } | |
| } | |
| def _summary_writer(summaries_dict: Dict[Text, Any]) -> None: | |
| """Adds scalar and image summaries.""" | |
| # Adds scalar summaries. | |
| for key, scalars in summaries_dict['scalar_summaries'].items(): | |
| tf.summary.scalar(key, scalars) | |
| # Adds image summaries. | |
| for key, images in summaries_dict['image_summaries'].items(): | |
| tf.summary.image(key, tf.clip_by_value(images, 0.0, 1.0)) | |
| tf.summary.histogram(key + '_h', images) | |
| def train_loop( | |
| strategy: tf.distribute.Strategy, | |
| train_set: tf.data.Dataset, | |
| create_model_fn: Callable[..., tf.keras.Model], | |
| create_losses_fn: Callable[..., Dict[str, Tuple[Callable[..., tf.Tensor], | |
| Callable[..., tf.Tensor]]]], | |
| create_optimizer_fn: Callable[..., tf.keras.optimizers.Optimizer], | |
| distributed_train_step_fn: Callable[[ | |
| tf.distribute.Strategy, Dict[str, tf.Tensor], tf.keras.Model, Dict[ | |
| str, | |
| Tuple[Callable[..., tf.Tensor], | |
| Callable[..., tf.Tensor]]], tf.keras.optimizers.Optimizer, int | |
| ], Dict[str, Any]], | |
| eval_loop_fn: Callable[..., None], | |
| create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]], | |
| eval_folder: Dict[str, Any], | |
| eval_datasets: Dict[str, tf.data.Dataset], | |
| summary_writer_fn: Callable[[Dict[str, Any]], None], | |
| train_folder: str, | |
| saved_model_folder: str, | |
| num_iterations: int, | |
| save_summaries_frequency: int = 500, | |
| save_checkpoint_frequency: int = 500, | |
| checkpoint_max_to_keep: int = 10, | |
| checkpoint_save_every_n_hours: float = 2., | |
| timing_frequency: int = 100, | |
| logging_frequency: int = 10): | |
| """A Tensorflow 2 eager mode training loop. | |
| Args: | |
| strategy: A Tensorflow distributed strategy. | |
| train_set: A tf.data.Dataset to loop through for training. | |
| create_model_fn: A callable that returns a tf.keras.Model. | |
| create_losses_fn: A callable that returns a tf.keras.losses.Loss. | |
| create_optimizer_fn: A callable that returns a | |
| tf.keras.optimizers.Optimizer. | |
| distributed_train_step_fn: A callable that takes a distribution strategy, a | |
| Dict[Text, tf.Tensor] holding the batch of training data, a | |
| tf.keras.Model, a tf.keras.losses.Loss, a tf.keras.optimizers.Optimizer, | |
| iteartion number to sample a weight value to loos functions, | |
| and returns a dictionary to be passed to the summary_writer_fn. | |
| eval_loop_fn: Eval loop function. | |
| create_metrics_fn: create_metric_fn. | |
| eval_folder: A path to where the summary event files and checkpoints will be | |
| saved. | |
| eval_datasets: A dictionary of evalution tf.data.Dataset to loop through for | |
| evaluation. | |
| summary_writer_fn: A callable that takes the output of | |
| distributed_train_step_fn and writes summaries to be visualized in | |
| TensorBoard. | |
| train_folder: A path to where the summaries event files and checkpoints | |
| will be saved. | |
| saved_model_folder: A path to where the saved models are stored. | |
| num_iterations: An integer, the number of iterations to train for. | |
| save_summaries_frequency: The iteration frequency with which summaries are | |
| saved. | |
| save_checkpoint_frequency: The iteration frequency with which model | |
| checkpoints are saved. | |
| checkpoint_max_to_keep: The maximum number of checkpoints to keep. | |
| checkpoint_save_every_n_hours: The frequency in hours to keep checkpoints. | |
| timing_frequency: The iteration frequency with which to log timing. | |
| logging_frequency: How often to output with logging.info(). | |
| """ | |
| logging.info('Creating training tensorboard summaries ...') | |
| summary_writer = tf.summary.create_file_writer(train_folder) | |
| if eval_datasets is not None: | |
| logging.info('Creating eval tensorboard summaries ...') | |
| eval_summary_writer = tf.summary.create_file_writer(eval_folder) | |
| train_set = strategy.experimental_distribute_dataset(train_set) | |
| with strategy.scope(): | |
| logging.info('Building model ...') | |
| model = create_model_fn() | |
| loss_functions = create_losses_fn() | |
| optimizer = create_optimizer_fn() | |
| if eval_datasets is not None: | |
| metrics = create_metrics_fn() | |
| logging.info('Creating checkpoint ...') | |
| checkpoint = tf.train.Checkpoint( | |
| model=model, | |
| optimizer=optimizer, | |
| step=optimizer.iterations, | |
| epoch=tf.Variable(0, dtype=tf.int64, trainable=False), | |
| training_finished=tf.Variable(False, dtype=tf.bool, trainable=False)) | |
| logging.info('Restoring old model (if exists) ...') | |
| checkpoint_manager = tf.train.CheckpointManager( | |
| checkpoint, | |
| directory=train_folder, | |
| max_to_keep=checkpoint_max_to_keep, | |
| keep_checkpoint_every_n_hours=checkpoint_save_every_n_hours) | |
| with strategy.scope(): | |
| if checkpoint_manager.latest_checkpoint: | |
| checkpoint.restore(checkpoint_manager.latest_checkpoint) | |
| logging.info('Creating Timer ...') | |
| timer = tf.estimator.SecondOrStepTimer(every_steps=timing_frequency) | |
| timer.update_last_triggered_step(optimizer.iterations.numpy()) | |
| logging.info('Training on devices: %s.', [ | |
| el.name.split('/physical_device:')[-1] | |
| for el in tf.config.get_visible_devices() | |
| ]) | |
| # Re-assign training_finished=False, in case we restored a checkpoint. | |
| checkpoint.training_finished.assign(False) | |
| while optimizer.iterations.numpy() < num_iterations: | |
| for i_batch, batch in enumerate(train_set): | |
| summary_writer.set_as_default() | |
| iterations = optimizer.iterations.numpy() | |
| if iterations % logging_frequency == 0: | |
| # Log epoch, total iterations and batch index. | |
| logging.info('epoch %d; iterations %d; i_batch %d', | |
| checkpoint.epoch.numpy(), iterations, | |
| i_batch) | |
| # Break if the number of iterations exceeds the max. | |
| if iterations >= num_iterations: | |
| break | |
| # Compute distributed step outputs. | |
| distributed_step_outputs = distributed_train_step_fn( | |
| strategy, batch, model, loss_functions, optimizer, iterations) | |
| # Save checkpoint, and optionally run the eval loops. | |
| if iterations % save_checkpoint_frequency == 0: | |
| checkpoint_manager.save(checkpoint_number=iterations) | |
| if eval_datasets is not None: | |
| eval_loop_fn( | |
| strategy=strategy, | |
| eval_base_folder=eval_folder, | |
| model=model, | |
| metrics=metrics, | |
| datasets=eval_datasets, | |
| summary_writer=eval_summary_writer, | |
| checkpoint_step=iterations) | |
| # Write summaries. | |
| if iterations % save_summaries_frequency == 0: | |
| tf.summary.experimental.set_step(step=iterations) | |
| summary_writer_fn(distributed_step_outputs) | |
| tf.summary.scalar('learning_rate', | |
| optimizer.learning_rate(iterations).numpy()) | |
| # Log steps/sec. | |
| if timer.should_trigger_for_step(iterations): | |
| elapsed_time, elapsed_steps = timer.update_last_triggered_step( | |
| iterations) | |
| if elapsed_time is not None: | |
| steps_per_second = elapsed_steps / elapsed_time | |
| tf.summary.scalar( | |
| 'steps/sec', steps_per_second, step=optimizer.iterations) | |
| # Increment epoch. | |
| checkpoint.epoch.assign_add(1) | |
| # Assign training_finished variable to True after training is finished and | |
| # save the last checkpoint. | |
| checkpoint.training_finished.assign(True) | |
| checkpoint_manager.save(checkpoint_number=optimizer.iterations.numpy()) | |
| # Generate a saved model. | |
| model.save(saved_model_folder) | |
| def train(strategy: tf.distribute.Strategy, train_folder: str, | |
| saved_model_folder: str, n_iterations: int, | |
| create_model_fn: Callable[..., tf.keras.Model], | |
| create_losses_fn: Callable[..., Dict[str, | |
| Tuple[Callable[..., tf.Tensor], | |
| Callable[..., | |
| tf.Tensor]]]], | |
| create_metrics_fn: Callable[..., Dict[str, tf.keras.metrics.Metric]], | |
| dataset: tf.data.Dataset, | |
| learning_rate: tf.keras.optimizers.schedules.LearningRateSchedule, | |
| eval_loop_fn: Callable[..., None], | |
| eval_folder: str, | |
| eval_datasets: Dict[str, tf.data.Dataset]): | |
| """Training function that is strategy agnostic. | |
| Args: | |
| strategy: A Tensorflow distributed strategy. | |
| train_folder: A path to where the summaries event files and checkpoints | |
| will be saved. | |
| saved_model_folder: A path to where the saved models are stored. | |
| n_iterations: An integer, the number of iterations to train for. | |
| create_model_fn: A callable that returns tf.keras.Model. | |
| create_losses_fn: A callable that returns the losses. | |
| create_metrics_fn: A function that returns the metrics dictionary. | |
| dataset: The tensorflow dataset object. | |
| learning_rate: Keras learning rate schedule object. | |
| eval_loop_fn: eval loop function. | |
| eval_folder: A path to where eval summaries event files and checkpoints | |
| will be saved. | |
| eval_datasets: The tensorflow evaluation dataset objects. | |
| """ | |
| train_loop( | |
| strategy=strategy, | |
| train_set=dataset, | |
| create_model_fn=create_model_fn, | |
| create_losses_fn=create_losses_fn, | |
| create_optimizer_fn=functools.partial( | |
| tf.keras.optimizers.Adam, learning_rate=learning_rate), | |
| distributed_train_step_fn=_distributed_train_step, | |
| eval_loop_fn=eval_loop_fn, | |
| create_metrics_fn=create_metrics_fn, | |
| eval_folder=eval_folder, | |
| eval_datasets=eval_datasets, | |
| summary_writer_fn=_summary_writer, | |
| train_folder=train_folder, | |
| saved_model_folder=saved_model_folder, | |
| num_iterations=n_iterations, | |
| save_summaries_frequency=3000, | |
| save_checkpoint_frequency=3000) | |
| def get_strategy(mode) -> tf.distribute.Strategy: | |
| """Creates a distributed strategy.""" | |
| strategy = None | |
| if mode == 'cpu': | |
| strategy = tf.distribute.OneDeviceStrategy('/cpu:0') | |
| elif mode == 'gpu': | |
| strategy = tf.distribute.MirroredStrategy() | |
| else: | |
| raise ValueError('Unsupported distributed mode.') | |
| return strategy | |